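This project aims to create a simple and scalable repo to reproduce [Sora](https://openai.com/sora) (by OpenAI, though we prefer to call it "ClosedAI").
This project hopes to reproduce Sora through the power of the open-source community. It was jointly initiated by the PKU-Rabbitpre AIGC Joint Lab, with in-depth contributions from Rabbitpre, Huawei, Pengcheng Laboratory, and open-source community partners.
The current v1.5 release is **trained entirely on Huawei Ascend (a pure-Ascend version)**. Pull requests and usage are welcome!
We are iterating quickly on new versions and welcome more collaborators and algorithm engineers to join us: [Algorithm Engineer Recruitment - Rabbitpre Intelligence.pdf](https://github.com/user-attachments/files/19107972/-.pdf)
If you like our project, please give us a star ⭐ on GitHub for the latest updates.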
# 📣 News
* **[2026.03.08]** 👋👋👋 We introduce [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model that achieves minute-scale, high-quality video synthesis at **19.5 FPS on a single H100** GPU, without relying on conventional long video anti-drifting strategies or standard video acceleration techniques. Check out the [Technical Report](https://huggingface.co/papers/2603.04379)!
* **[2025.06.05]** 🔥🔥🔥 We release version 1.5.0, our most powerful model! By introducing a **higher-compression WFVAE** and an improved sparse DiT architecture, **SUV**, we achieve performance **comparable to HunyuanVideo (Open-Source)** using an 8B-scale model and 40 million video samples. Version 1.5.0 is **fully trained and inferred on Ascend 910-series accelerators**. Please check the [mindspeed_mmdit](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/mindspeed_mmdit) branch for our new code and [Report-v1.5.0.md](docs/Report-v1.5.0.md) for our report. The GPU version is coming soon.
* **[2024.12.03]** ⚡️ We released our [arXiv paper](https://arxiv.org/abs/2412.00131) and WF-VAE [paper](https://arxiv.org/abs/2411.17459) for v1.3. A more powerful version is coming soon.
* **[2024.10.16]** 🎉 We released version 1.3.0, featuring: **WFVAE**, **prompt refiner**, **data filtering strategy**, **sparse attention**, and **bucket training strategy**. We also support 93x480p within **24G VRAM**. More details can be found at our latest [report](docs/Report-v1.3.0.md).
* **[2024.08.13]** 🎉 We are launching Open-Sora Plan v1.2.0 **I2V** model, which is based on Open-Sora Plan v1.2.0. The current version supports image-to-video generation and transition generation (the starting and ending frames conditions for video generation). Check out the Image-to-Video section in this [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.2.0.md#training-image-to-video-diffusion-model).
* **[2024.07.24]** 🔥🔥🔥 v1.2.0 is here! Utilizing a 3D full attention architecture instead of 2+1D. We released a true 3D video diffusion model trained on 4s 720p. Check out our latest [report](docs/Report-v1.2.0.md).
* **[2024.05.27]** 🎉 We are launching Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out our latest [report](docs/Report-v1.1.0.md). Thanks to [ShareGPT4Video's](https://sharegpt4video.github.io/) capability to annotate long videos.
* **[2024.04.09]** 🤝 Excited to share our latest exploration on metamorphic time-lapse video generation: [MagicTime](https://github.com/PKU-YuanGroup/MagicTime), which learns real-world physics knowledge from time-lapse videos.
* **[2024.04.07]** 🎉🎉🎉 Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.0.0.md). Thanks to HUAWEI NPU for supporting us.
* **[2024.03.27]** 🚀🚀🚀 We release the report of [VideoCausalVAE](docs/CausalVideoVAE.md), which supports both images and videos. We present our reconstructed video in this demonstration as follows. The text-to-video model is on the way.
* **[2024.03.01]** 🤗 We launched a plan to reproduce Sora, called Open-Sora Plan! Welcome to **watch** 👀 this repository for the latest updates.
# 😍 Gallery
Text-to-Video Generation of Open-Sora Plan v1.5.0.
### YouTube:
[Watch on YouTube](https://youtu.be/IiWTdx2EHCY)
### Bilibili:
[Watch on Bilibili](https://www.bilibili.com/video/BV1X77tzxE3b/)
# 😮 Highlights
Open-Sora Plan shows excellent performance in video generation.
### 🔥 WFVAE with higher performance and compression
- WFVAE uses an 8×8×8 downsampling rate yet achieves a higher PSNR than the VAE used in Wan2.1. The higher compression lowers the training cost of the DiT built upon it.
### 🚀 More powerful sparse DiT
- SUV, a more powerful sparse attention architecture, achieves performance close to dense DiT while providing a speedup of over 35%.
# 🐳 Resource
| Version | Architecture | Diffusion Model | CausalVideoVAE | Data | Prompt Refiner |
|:---|:---|:---|:---|:---|:---|
| v1.5.0 | SUV (Skiparse 3D) | [121x576x1024](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.5.0/blob/main/MindSpeed/model_ema.pt)[5] | [Anysize_8x8x8_32dim](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.5.0/blob/main/MindSpeed/wfvae_888_dim32.ckpt) | - | - |
| v1.3.0 [4] | Skiparse 3D | [Anysize in 93x640x640](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640)[3], [Anysize in 93x640x640_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640_i2v)[3] | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae)| [prompt_refiner](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) | [checkpoint](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) |
| v1.2.0 | Dense 3D | [93x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p), [29x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p)[1], [93x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p)[1,2], [29x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p), [1x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p), [93x480p_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p_i2v) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/vae)| [Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) | - |
| v1.1.0 | 2+1D | [221x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512), [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) |[Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/vae) |[Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0)| - |
| v1.0.0 | 2+1D | [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512), [65x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256), [17x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/vae) | [Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)| - |
> [1] Please note that the weights for v1.2.0 29×720p and 93×480p were trained on Panda70M and have not undergone final high-quality data fine-tuning, so they may produce watermarks.
> [2] We fine-tuned 3.5k steps from 93×720p to get 93×480p for community research use.
> [3] The model is trained on arbitrary resolutions with a stride of 32, so keep the inference resolution a multiple of 32. The number of frames must be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image).
> [4] Model weights are also available at [OpenMind](https://modelers.cn/models/linbin/Open-Sora-Plan-v1.3.0) and [WiseModel](https://wisemodel.cn/models/PKU-YUAN/Open-Sora-Plan-v1.3.0).
> [5] The current model weights are only compatible with the NPU + MindSpeed-MM framework. Model weights are also available at [modelers](https://modelers.cn/models/PKU-YUAN-Group/Open-Sora-Plan-v1.5.0/tree/main/MindSpeed).
> [!Warning]
> 🚨 For version 1.2.0, we no longer support 2+1D models.
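Note [3] above constrains the shapes that v1.3.0 can run at inference time. A minimal sketch of a pre-flight check follows; `check_v130_shape` is a hypothetical helper name and not part of the repository, and it only encodes the two rules stated in the note.

```python
def check_v130_shape(num_frames: int, height: int, width: int, stride: int = 32) -> list:
    """Return a list of problems with a requested v1.3.0 inference shape.

    Hypothetical helper (not part of Open-Sora Plan); it only encodes
    the two rules stated in note [3].
    """
    problems = []
    # Rule 1: height and width must be multiples of the training stride (32).
    for name, value in (("height", height), ("width", width)):
        if value <= 0 or value % stride != 0:
            problems.append(f"{name}={value} is not a positive multiple of {stride}")
    # Rule 2: the frame count must have the form 4n + 1 (1 means a single image).
    if num_frames < 1 or (num_frames - 1) % 4 != 0:
        problems.append(f"num_frames={num_frames} is not of the form 4n+1")
    return problems


print(check_v130_shape(93, 640, 640))  # valid shape: empty list
print(check_v130_shape(90, 650, 640))  # one frame problem and one height problem
```

Running a check like this before launching inference avoids a failed run caused by an unsupported shape.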
# ⚙️ How to start
### GPU
coming soon...
### NPU
Please check out the **[mindspeed_mmdit](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/mindspeed_mmdit)** branch and follow the README.md for configuration.
# 📖 Technical report
Please check [Report-v1.5.0.md](docs/Report-v1.5.0.md).
# 💡 How to Contribute
We greatly appreciate your contributions to the Open-Sora Plan open-source community and your help in making it even better!
For more details, please refer to the [Contribution Guidelines](docs/Contribution_Guidelines.md).
# 👍 Acknowledgement and Related Work
* [Allegro](https://github.com/rhymes-ai/Allegro): Allegro is a powerful text-to-video model that generates high-quality videos up to 6 seconds at 15 FPS and 720p resolution from simple text input based on our Open-Sora Plan. The significance of open-source is becoming increasingly tangible.
* [Latte](https://github.com/Vchitect/Latte): It is a wonderful 2+1D video generation model.
* [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis.
* [ShareGPT4Video](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4Video): Improving Video Understanding and Generation with Better Captions.
* [VideoGPT](https://github.com/wilson1yan/VideoGPT): Video Generation using VQ-VAE and Transformers.
* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers.
* [FiT](https://github.com/whlzy/FiT): Flexible Vision Transformer for Diffusion Model.
* [Positional Interpolation](https://arxiv.org/abs/2306.15595): Extending Context Window of Large Language Models via Positional Interpolation.
# 🔒 License
* See [LICENSE](LICENSE) for details.
## ✨ Star History
[Star History Chart](https://star-history.com/#PKU-YuanGroup/Open-Sora-Plan&Date)
# ✏️ Citing
```bibtex
@article{lin2024open,
title={Open-Sora Plan: Open-Source Large Video Generation Model},
author={Lin, Bin and Ge, Yunyang and Cheng, Xinhua and Li, Zongjian and Zhu, Bin and Wang, Shaodong and He, Xianyi and Ye, Yang and Yuan, Shenghai and Chen, Liuhan and others},
journal={arXiv preprint arXiv:2412.00131},
year={2024}
}
```
```bibtex
@article{helios,
title={Helios: Real Real-Time Long Video Generation Model},
author={Yuan, Shenghai and Yin, Yuanyang and Li, Zongjian and Huang, Xinwei and Yang, Xiao and Yuan, Li},
journal={arXiv preprint arXiv:2603.04379},
year={2026}
}
```
```bibtex
@article{li2024wf,
title={WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model},
author={Li, Zongjian and Lin, Bin and Ye, Yang and Chen, Liuhan and Cheng, Xinhua and Yuan, Shenghai and Yuan, Li},
journal={arXiv preprint arXiv:2411.17459},
year={2024}
}
```
# 🤝 Community contributors
================================================
FILE: docs/Contribution_Guidelines.md
================================================
# Contributing to the Open-Sora Plan Community
The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights!
## Submitting a Pull Request (PR)
As a contributor, before submitting your request, kindly follow these guidelines:
1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work.
2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and clone your fork to your local machine.
```bash
git clone [your-forked-repository-url]
```
3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates:
```bash
git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan
```
4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository.
```bash
# Pull the latest code from the upstream branch
git fetch upstream
# Switch to the main branch
git checkout main
# Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream
git merge upstream/main
# Additionally, sync the local main branch to the remote branch of your forked repository
git push origin main
```
> Note: Sync the code from the main repository before each submission.
5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful.
```bash
git checkout -b my-docs-branch main
```
6. While making modifications and committing changes, adhere to our [Commit Message Format](#commit-message-format).
```bash
git commit -m "[docs]: xxxx"
```
7. Push your changes to your GitHub repository.
```bash
git push origin my-docs-branch
```
8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page.
## Commit Message Format
Commit messages must include both `<type>` and `<summary>` sections.
```bash
[<type>]: <summary>
   │         │
   │         └─⫸ Briefly describe your changes, without ending with a period.
   │
   └─⫸ Commit Type: |docs|feat|fix|refactor|
```
### Type
* **docs**: Modify or add documents.
* **feat**: Introduce a new feature.
* **fix**: Fix a bug.
* **refactor**: Restructure code, excluding new features or bug fixes.
### Summary
Describe modifications in English, without ending with a period.
> e.g., git commit -m "[docs]: add a contributing.md file"
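The format rules above are simple enough to check automatically. The sketch below is one possible checker, for example for a local commit hook; `is_valid_commit_message` and the regular expression are hypothetical and not part of the repository. It assumes a single-line message of the form `[type]: summary`.

```python
import re

# Hypothetical checker: the pattern encodes only the rules in this guide.
COMMIT_TYPES = ("docs", "feat", "fix", "refactor")
COMMIT_PATTERN = re.compile(r"\[(%s)\]: (\S.*)" % "|".join(COMMIT_TYPES))


def is_valid_commit_message(message: str) -> bool:
    """Return True if a single-line message follows '[type]: summary'."""
    match = COMMIT_PATTERN.fullmatch(message)
    if match is None:
        return False  # unknown type, missing brackets, or empty summary
    summary = match.group(2)
    # The summary must not end with a period.
    return not summary.rstrip().endswith(".")


print(is_valid_commit_message("[docs]: add a contributing.md file"))  # True
print(is_valid_commit_message("[docs]: add a contributing.md file."))  # False
```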
This guideline is borrowed from [minisora](https://github.com/mini-sora/minisora). We sincerely thank the MiniSora authors for their excellent templates.
================================================
FILE: docs/Prompt_Refiner.md
================================================
## Data
We have open-sourced our dataset of 32,555 pairs, which includes Chinese data. The dataset is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). The details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md#prompt-refiner).
The dataset is a single JSON file with the following structure.
```
[
{
"instruction": "Refine the sentence: \"A newly married couple sharing a piece of there wedding cake.\" to contain subject description, action, scene description. (Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.",
"input": "",
"output": "The newlywed couple, dressed in elegant attire..."
},
...
]
```
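Before training, it can help to confirm that a data file matches the structure shown above. The sketch below is a minimal validator using only the standard library; `validate_refiner_records` is a hypothetical helper, and the sample record is made up for illustration rather than taken from the dataset.

```python
import json

REQUIRED_KEYS = ("instruction", "input", "output")


def validate_refiner_records(records):
    """Check the list-of-records layout shown above and return the record count.

    Hypothetical helper: it only verifies that each record is an object
    whose three fields are strings.
    """
    if not isinstance(records, list):
        raise ValueError("top-level JSON value must be a list")
    for index, record in enumerate(records):
        if not isinstance(record, dict):
            raise ValueError(f"record {index} is not a JSON object")
        for key in REQUIRED_KEYS:
            if not isinstance(record.get(key), str):
                raise ValueError(f"record {index}: field {key!r} is missing or not a string")
    return len(records)


# Made-up sample in the same layout; a real file would be read with
# json.load(open(path, encoding="utf-8")).
sample_text = json.dumps([
    {
        "instruction": 'Refine the sentence: "a dog is running" ...',
        "input": "",
        "output": "A golden retriever sprints across a sunlit meadow.",
    }
])
print(validate_refiner_records(json.loads(sample_text)))  # 1
```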
## Train
`--data_path` is the path to the prepared JSON file.
`--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and some weight files.
`--lora_out_path` is the path where the LoRA model will be saved.
```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python train.py \
--data_path path/to/data.json \
--model_path path/to/llama_model \
--lora_out_path path/to/save/lora_model
```
## Merge
`--base_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and some weight files.
`--lora_in_path` is the directory containing the trained LoRA model.
`--lora_out_path` is the path where the merged model will be saved.
```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python merge.py \
--base_path path/to/llama_model \
--lora_in_path path/to/save/lora_model \
--lora_out_path path/to/save/merge_model
```
## Inference
`--mode_path` is the directory containing the weights (LLaMA 3.1 or the merged LoRA weights), including `config.json` and some weight files.
`--prompt` is the input text to be refined.
```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python inference.py \
--mode_path path/to/save/merge_model \
--prompt "your prompt text"
```
================================================
FILE: docs/Report-v1.0.0-cn.md
================================================
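# Technical Report v1.0.0
In March 2024, we launched Open-Sora-Plan, an open-source plan that aims to reproduce OpenAI's [Sora](https://openai.com/sora). As a foundational open-source framework, it supports training video generation models, including unconditional video generation, class-guided video generation, and text-to-video generation.
**Today, we are excited to present Open-Sora-Plan v1.0.0, which greatly improves video generation quality and text control.**
Compared with previous video generation models, Open-Sora-Plan v1.0.0 brings the following improvements:
1. **Efficient training and inference with CausalVideoVAE**. We compress videos temporally and spatially by 4×8×8.
2. **Joint image-video training for better visual quality**. CausalVideoVAE treats the first frame as an image, so it naturally supports encoding both images and videos. This lets the diffusion model capture more spatio-temporal detail and improve quality.
### Open-Source Release
We open-source Open-Sora-Plan to promote further development of the video generation community. The code, data, and models are publicly available.
- Online demos: [Hugging Face Space](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0), [Replicate](https://replicate.com/camenduru/open-sora-plan-512x512), and [Colab](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb). Thanks to [@camenduru](https://github.com/camenduru) for generously supporting our work! 🤝
- Code: all training scripts and sampling code.
- Models: the diffusion model and CausalVideoVAE, available [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0).
- Data: all raw videos and their captions, available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).
## Results
Open-Sora-Plan v1.0.0 supports joint image-video training. Here we show video and image reconstruction and generation:
**Video reconstruction** at 720×1280. Because of GitHub's upload limits, the original videos are hosted here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8).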
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68
**Image reconstruction** at 1536×1024
**Text-to-video generation** at 65×1024×1024
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011
**Text-to-video generation** at 65×512×512
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e
**Text-to-image generation** at 512×512

## Detailed Technical Report
### CausalVideoVAE
#### Model Structure

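The causal VAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the image VAE can be applied seamlessly to the video VAE, the model structure is designed as follows:
1. **CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training on images and videos. CausalConv3D treats the first frame specially, because it has no access to subsequent frames. For more details, see https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145
2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) for expanding Conv2D to Conv3D: average initialization and center initialization. We instead use a specific method, tail initialization, which lets the model reconstruct images, and even videos, directly without any training.
#### Training Details
We show the loss curves of two initialization methods at 17×256×256. The yellow curve is the loss with tail initialization, and the blue curve is the loss with center initialization. As the curves show, tail initialization performs better. We also found that center initialization leads to error accumulation, which causes collapse over long durations.
#### Inference Tricks
Although the VAE stays frozen while training the diffusion model, we still cannot afford the cost of CausalVideoVAE. In our experiments, 80 GB of GPU memory can only run inference on a single 256×512×512 or 32×1024×1024 video in half precision, which limits scaling to longer and higher-resolution videos. We therefore adopt tile convolution, which can run inference on videos of arbitrary duration or resolution with nearly constant memory.
### Data Construction
We define a high-quality video dataset by two core principles: (1) no content-unrelated watermarks, and (2) high-quality text annotations.
**For principle 1**, we crawled about 40k videos from open-source websites (CC0 license): 1,234 from [mixkit](https://mixkit.co/), 7,408 from [pexels](https://www.pexels.com/), and 31,616 from [pixabay](https://pixabay.com/). Using the scene-cut splitting script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md), we split these videos into about 434k video clips. In fact, according to our splitting results, 99% of the videos crawled from these websites contain a single scene. We also found that more than 60% of the crawled data consists of landscape-related videos. More details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).
**For principle 2**, it is hard to crawl large amounts of high-quality text annotations directly from the web. We therefore use mature image captioning models to obtain high-quality dense captions. We ran ablations on two multimodal large models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is built specifically for captioning, while the latter is a general-purpose multimodal large model. In our ablations, their captioning performance was similar. However, their inference speed on an A800 differs greatly: 40 s/it at a batch size of 12 for ShareGPT4V-Captioner-7B versus 15 s/it at a batch size of 1 for LLaVA-1.6-34B. We open-source all [captions and raw videos](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).
| Model | Average length | Max | Std |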
|---|---|---|---|
| ShareGPT4V-Captioner-7B | 170.0827524529121 | 467 | 53.689967539537776 |
| LLaVA-1.6-34B | 141.75851073472666 | 472 | 48.52492072346965 |
### 训练扩散模型
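Similar to previous work, we use a multi-stage cascaded training strategy, consuming 2,048 A800 GPU hours in total. We found that joint training with images significantly speeds up convergence and improves visual quality, which is consistent with [Latte](https://github.com/Vchitect/Latte). Our training costs are listed below.
| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |
|---|---|---|---|---|
| Training video size | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 |
| Compute (#A800 GPUs × #hours) | 32 × 40 | 32 × 18 | 32 × 6 | Under training |
| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | Under training |
| Log | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | Under training |
| Training data | ~40k videos | ~40k videos | ~40k videos | ~40k videos |
## Next Release Preview
### CausalVideoVAE
The CausalVideoVAE v1.0.0 we have released has two main drawbacks: **motion blur** and **gridding artifacts**. We have made a series of improvements to CausalVideoVAE that lower its inference cost and strengthen its performance. For now we call it the preview version, and it will be released in the next version.
**1-minute 720×1280 video reconstruction**. Because of GitHub's upload limits, the videos are hosted here: [original video](https://streamable.com/u4onbb), [reconstructed video](https://streamable.com/qt8ncc).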
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b
We randomly selected 100 samples from the Kinetics-400 validation set for evaluation. The results are shown in the table below:
| | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |
|---|---|---|---|---|
| v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 |
| Preview | 0.877 | 0.064 | 29.695 | 0.070 |
#### Motion Blur
| **v1.0.0** | **Preview** |
| --- | --- |
|  |  |
#### Gridding Effect
| **v1.0.0** | **Preview** |
| --- | --- |
|  |  |
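### Data Construction
**Data sources**: As mentioned above, more than 60% of our dataset consists of landscape videos, which limits our open-domain video generation ability. However, most current large-scale open-source datasets are crawled from YouTube. Although they contain many videos, we are concerned about whether the quality of the videos themselves is adequate. We will therefore keep collecting high-quality data, and we welcome recommendations from the open-source community.
**Caption generation pipeline**: As the training video duration grows, we have to consider video captioning methods that are more efficient than multimodal image models. We are developing a new video caption generation pipeline that supports long videos well. Stay tuned.
### Training Diffusion Model
Although v1.0.0 shows promising results, we are still some distance from Sora. Our upcoming work focuses on three aspects:
1. **Training with dynamic resolution and duration**: We aim to develop techniques for training models at varying resolutions and durations, making training more flexible and adaptable.
2. **Longer video generation**: We will explore ways to extend the model's generation capability so that it can produce longer videos beyond the current limits.
3. **More conditional control**: We aim to strengthen the model's conditional control, giving users more options and more control over the generated videos.
In addition, by closely inspecting the generated videos, we found some implausible speckles and abnormal flow, which are caused by the limited performance of CausalVideoVAE, as mentioned above. In future experiments, we will retrain a diffusion model with a stronger VAE.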
================================================
FILE: docs/Report-v1.0.0.md
================================================
# Report v1.0.0
In March 2024, we launched Open-Sora-Plan, which aims to reproduce OpenAI's [Sora](https://openai.com/sora) through an open-source framework. As a foundational open-source framework, it enables training of video generation models, including unconditional video generation, class-conditional video generation, and text-to-video generation.
**Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities.**
Compared with previous video generation models, Open-Sora-Plan v1.0.0 has several improvements:
1. **Efficient training and inference with CausalVideoVAE**. We compress videos by 4× temporally and 8×8 spatially.
2. **Joint image-video training for better quality**. Our CausalVideoVAE treats the first frame as an image, so it naturally encodes both images and videos. This allows the diffusion model to capture more spatial-visual details and improve visual quality.
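To make the 4×8×8 compression concrete, the sketch below computes the latent shape for a given clip. It assumes that, because the first frame is encoded on its own as an image, a clip of T frames maps to (T − 1) / 4 + 1 latent frames; the report does not spell out this formula, and `latent_shape` is a hypothetical helper.

```python
def latent_shape(num_frames, height, width, t_stride=4, s_stride=8):
    """Latent (frames, height, width) under 4x8x8 compression.

    Assumption: the first frame is kept as its own latent frame and the
    remaining frames are compressed in groups of `t_stride`.
    """
    if num_frames < 1 or (num_frames - 1) % t_stride != 0:
        raise ValueError("num_frames must have the form t_stride * n + 1")
    if height % s_stride != 0 or width % s_stride != 0:
        raise ValueError("height and width must be multiples of s_stride")
    return ((num_frames - 1) // t_stride + 1, height // s_stride, width // s_stride)


print(latent_shape(65, 512, 512))  # (17, 64, 64)
print(latent_shape(1, 512, 512))   # (1, 64, 64): a single image
```

Under this assumption, the 17-frame and 65-frame training sizes used in this report both fit the pattern, and a single image is the special case with one frame.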
### Open-Source Release
We open-source Open-Sora-Plan to facilitate future development of video generation in the community. Code, data, and models are made publicly available.
- Demo: [Hugging Face demo](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0). 🤝 Enjoy the [Replicate demo](https://replicate.com/camenduru/open-sora-plan-512x512) and [Colab notebook](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), created by [@camenduru](https://github.com/camenduru), who generously supports our research!
- Code: All training scripts and sample scripts.
- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0).
- Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).
## Gallery
Open-Sora-Plan v1.0.0 supports joint training of images and videos. Here, we present the capabilities of Video/Image Reconstruction and Generation:
### CausalVideoVAE Reconstruction
**Video Reconstruction** at 720×1280. Because GitHub limits upload size, the original videos are hosted here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8).
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68
**Image Reconstruction** in 1536×1024.
**Text-to-Video Generation** with 65×1024×1024
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011
**Text-to-Video Generation** with 65×512×512
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e
**Text-to-Image Generation** with 512×512

## Detailed Technical Report
### CausalVideoVAE
#### Model Structure

The CausalVideoVAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the Image VAE can be seamlessly applied to the Video VAE, the model structure has been designed as follows:
1. **CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training of image and video data. CausalConv3D applies a special treatment to the first frame, as it does not have access to subsequent frames. For more specific details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145
2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. But we employ a specific initialization method (tail initialization). This initialization method ensures that without any training, the model is capable of directly reconstructing images, and even videos.
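The effect of the two design choices above can be illustrated with a toy example that looks at the time axis only. This is a sketch under stated assumptions, not the repository's implementation: it assumes causal padding repeats the first frame, and that tail initialization places the 2D weight in the last temporal tap with zeros elsewhere. Each frame is reduced to a single number and the 2D weight to a scalar.

```python
def causal_temporal_conv(frames, kernel):
    """Apply a causal 1-D convolution along time to a list of per-frame values."""
    k = len(kernel)
    # Assumption: pad the front with k-1 copies of the first frame, so the
    # output at time t depends only on frames at or before t.
    padded = [frames[0]] * (k - 1) + list(frames)
    return [
        sum(kernel[j] * padded[t + j] for j in range(k))
        for t in range(len(frames))
    ]


def inflate_weight(weight_2d, k, method):
    """Inflate a scalar stand-in for a Conv2D weight into k temporal taps."""
    if method == "tail":      # assumed: 2D weight in the last tap, zeros elsewhere
        return [0.0] * (k - 1) + [weight_2d]
    if method == "center":    # 2D weight in the middle tap
        taps = [0.0] * k
        taps[k // 2] = weight_2d
        return taps
    if method == "average":   # 2D weight spread evenly over all taps
        return [weight_2d / k] * k
    raise ValueError(f"unknown method: {method}")


frames = [1.0, 2.0, 3.0, 4.0]
for method in ("tail", "center", "average"):
    print(method, causal_temporal_conv(frames, inflate_weight(1.0, 3, method)))
```

In this toy setting, tail initialization returns every frame unchanged, so the inflated layer behaves like the original per-frame 2D layer before any training. Center initialization returns the sequence delayed by one frame, and average initialization blends neighboring frames.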
#### Training Details
We present the loss curves for two initialization methods at 17×256×256. The yellow curve is the loss with tail initialization, and the blue curve is the loss with center initialization. As the curves show, tail initialization performs better. We also found that center initialization leads to error accumulation, which causes collapse over extended durations.
#### Inference Tricks
Despite the VAE in Diffusion training being frozen, we still find it challenging to afford the cost of the CausalVideoVAE. In our case, with 80GB of GPU memory, we can only infer a video of either 256×512×512 or 32×1024×1024 resolution using half-precision, which limits our ability to scale up to longer and higher-resolution videos. Therefore, we adopt tile convolution, which allows us to infer videos of arbitrary duration or resolution with nearly constant memory usage.
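The idea behind tiling is that peak memory depends on the tile size rather than on the full video size. The sketch below only shows how one axis can be split into overlapping windows; the tile size and overlap are illustrative assumptions, since the report does not state the values used, and `tile_ranges` is a hypothetical helper.

```python
def tile_ranges(length, tile, overlap):
    """Split [0, length) into windows of at most `tile` elements.

    Consecutive windows share `overlap` elements so that results can be
    blended at the seams. Hypothetical helper for illustration only.
    """
    if length <= 0:
        raise ValueError("length must be positive")
    if not 0 <= overlap < tile:
        raise ValueError("overlap must satisfy 0 <= overlap < tile")
    step = tile - overlap
    ranges = []
    start = 0
    while True:
        end = min(start + tile, length)
        ranges.append((start, end))
        if end == length:
            break
        start += step
    return ranges


# A 1024-pixel axis processed in 512-pixel tiles with a 64-pixel overlap.
print(tile_ranges(1024, 512, 64))  # [(0, 512), (448, 960), (896, 1024)]
```

Applying the same split to height, width, and time yields tiles that can be encoded or decoded one at a time, which keeps memory roughly constant as the video grows.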
### Data Construction
We define a high-quality video dataset based on two core principles: (1) No content-unrelated watermarks. (2) High-quality and dense captions.
**For principle 1**, we crawled approximately 40,000 videos from open-source websites under the CC0 license. Specifically, we obtained 1,234 videos from [mixkit](https://mixkit.co/), 7,408 videos from [pexels](https://www.pexels.com/), and 31,616 videos from [pixabay](https://pixabay.com/). These videos have no content-unrelated watermarks. Using the scene-cut splitting script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md), we divided these videos into approximately 434,000 video clips. In fact, based on our splitting results, 99% of the videos obtained from these sources contain a single scene. We also observed that over 60% of the crawled data consists of landscape videos. More details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).
**For principle 2**, it is challenging to crawl a large quantity of high-quality dense captions directly from the internet. Therefore, we use a mature image captioning model to obtain high-quality dense captions. We conducted ablation experiments on two multimodal large models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is specifically designed for caption generation, while the latter is a general-purpose multimodal large model. In our ablation experiments, the two were comparable in caption quality. However, their inference speed on the A800 GPU differs significantly: 40 s/it at a batch size of 12 for ShareGPT4V-Captioner-7B versus 15 s/it at a batch size of 1 for LLaVA-1.6-34B, that is, about 3.3 s versus 15 s per sample. We open-source all annotations [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0). We show some statistics below. We set the maximum text length of the model to 300, which covers almost 99% of the samples.
| Name | Avg length | Max | Std |
|---|---|---|---|
| ShareGPT4V-Captioner-7B | 170.0827524529121 | 467 | 53.689967539537776 |
| LLaVA-1.6-34B | 141.75851073472666 | 472 | 48.52492072346965 |
### Training Diffusion Model
Similar to previous work, we employ a multi-stage cascaded training approach, which consumed a total of 2,528 A800 GPU hours. We found that joint training with images significantly accelerates model convergence and improves visual quality, in line with the findings of [Latte](https://github.com/Vchitect/Latte). Below is our training card:
| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |
|---|---|---|---|---|
| Training Video Size | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 |
| Compute (#A800 GPU x #Hours) | 32 × 40 | 32 × 22 | 32 × 17 | Under training |
| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | Under training |
| Log | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | Under training |
| Training Data | ~40k videos | ~40k videos | ~40k videos | ~40k videos |
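The total of 2,528 GPU hours follows directly from the three completed stages in the table. A short check, using only the numbers listed above:

```python
# (number of A800 GPUs, hours) for each completed stage in the table above.
stages = {"Stage 1": (32, 40), "Stage 2": (32, 22), "Stage 3": (32, 17)}

gpu_hours = {name: gpus * hours for name, (gpus, hours) in stages.items()}
total_gpu_hours = sum(gpu_hours.values())

print(gpu_hours)        # {'Stage 1': 1280, 'Stage 2': 704, 'Stage 3': 544}
print(total_gpu_hours)  # 2528
```

Stage 4 is still under training, so it is not included in the total.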
## Next Release Preview
### CausalVideoVAE
Currently, the released version of CausalVideoVAE (v1.0.0) has two main drawbacks: **motion blurring** and **gridding effect**. We have made a series of improvements to CausalVideoVAE to reduce its inference cost and enhance its performance. We currently refer to this enhanced version as the "preview version", and it will be released in the next update. A preview reconstruction is shown below:
**1 min Video Reconstruction at 720×1280**. Because GitHub limits upload size, the videos are hosted here: [original video](https://streamable.com/u4onbb), [reconstructed video](https://streamable.com/qt8ncc).
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b
We randomly selected 100 samples from the validation set of Kinetics-400 for evaluation, and the results are presented in the following table:
| | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |
|---|---|---|---|---|
| v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 |
| Preview | 0.877 | 0.064 | 29.695 | 0.070 |
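For readers unfamiliar with the PSNR column, the sketch below implements the standard definition, 10·log10(MAX² / MSE), on flat lists of pixel values. It assumes 8-bit pixels (MAX = 255); the report does not state the pixel range or how per-frame values were averaged, so this is not necessarily the exact evaluation procedure used.

```python
import math


def psnr(reference, reconstruction, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel lists."""
    if len(reference) != len(reconstruction) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixels.
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)


original = [52, 55, 61, 59]
decoded = [52, 54, 61, 60]
print(psnr(original, decoded))
```

A higher PSNR means a smaller pixel-wise error, which is why the table marks it with an upward arrow.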
#### Motion Blurring
| **v1.0.0** | **Preview** |
| --- | --- |
|  |  |
#### Gridding effect
| **v1.0.0** | **Preview** |
| --- | --- |
|  |  |
### Data Construction
**Data source**. As mentioned earlier, over 60% of our dataset consists of landscape videos, which limits our ability to generate videos in other domains. However, most current large-scale open-source datasets are obtained by scraping platforms such as YouTube. While these datasets provide a vast quantity of videos, we have concerns about the quality of the videos themselves. Therefore, we will continue to collect high-quality datasets, and we welcome recommendations from the open-source community. We are launching an Open-Sora-Dataset project. Check out the details at [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).
**Caption Generation Pipeline**. As the video duration increases, we need to consider more efficient methods for video caption generation instead of relying solely on large multimodal image models. We are currently developing a new video caption generation pipeline that provides robust support for long videos. We are excited to share more details with you in the near future. Stay tuned!
### Training Diffusion Model
Although v1.0.0 has shown promising results, we acknowledge that we still have a ways to go to reach the level of Sora. In our upcoming work, we will primarily focus on three aspects:
1. **Training support for dynamic resolution and duration**: We aim to develop techniques that enable training models with varying resolutions and durations, allowing for more flexible and adaptable training processes.
2. **Support for longer video generation**: We will explore methods to extend the generation capabilities of our models, enabling them to produce longer videos beyond the current limitations.
3. **Enhanced conditional control**: We seek to enhance the conditional control capabilities of our models, providing users with more options and control over the generated videos.
Furthermore, through careful observation of the generated videos, we have noticed the presence of some non-physiological speckles or abnormal flow. This can be attributed to the limited performance of CausalVideoVAE, as mentioned earlier. In future experiments, we plan to retrain a diffusion model using a more powerful version of CausalVideoVAE to address these issues.
================================================
FILE: docs/Report-v1.1.0.md
================================================
# Report v1.1.0
In April 2024, we launched Open-Sora-Plan v1.0.0, featuring a simple and efficient design along with remarkable performance in text-to-video generation. Its model and data have already been adopted as a foundation in numerous research projects.
**Today, we are excited to present Open-Sora-Plan v1.1.0, which significantly improves video generation quality and duration.**
Compared to the previous version, Open-Sora-Plan v1.1.0 includes the following improvements:
1. **Better compressed visual representations**. We optimized the CausalVideoVAE architecture, which now has stronger performance and higher inference efficiency.
2. **Generate higher quality, longer videos**. We used higher quality visual data and captions by [ShareGPT4Video](https://sharegpt4video.github.io/), enabling the model to better understand the workings of the world.
Along with performance improvements, Open-Sora-Plan v1.1.0 maintains the minimalist design and data efficiency of v1.0.0. Remarkably, we found that v1.1.0 exhibits similar performance to the Sora base model, indicating that our version's evolution aligns with the scaling law demonstrated by Sora.
### Open-Source Release
We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model will be made publicly available.
- Demo: Hugging Face demo [here](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0).
- Code: All training scripts and sample scripts.
- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0).
- Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0).
## Gallery
### 221×512×512 Text-to-Video Generation
| 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) |
| --- | --- | --- | --- |
| | | | |
| This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage ... | a cat wearing sunglasses and working as a lifeguard at pool. | Photorealistic closeup video of two pirate ships battling each other as they sail ... | A movie trailer featuring the adventures of the 30 year old spaceman wearing a red wool ... |
| | | | |
| A snowy forest landscape with a dirt road running through it. The road is flanked by ... | Drone shot along the Hawaii jungle coastline, sunny day. Kayaks in the water. | Alpacas wearing knit wool sweaters, graffiti background, sunglasses. | The camera rotates around a large stack of vintage televisions all showing different ... |
| | | | |
| A drone camera circles around a beautiful historic church built on a rocky outcropping ... | Aerial view of Santorini during the blue hour, showcasing the stunning architecture ... | A robot dog explores the surface of Mars, kicking up red dust as it investigates ... | An aerial shot of a lighthouse standing tall on a rocky cliff, its beacon cutting ... |
| | | | |
| 3D animation of a small, round, fluffy creature with big, expressive eyes explores ... | A corgi vlogging itself in tropical Maui. | A single drop of liquid metal falls from a floating orb, landing on a mirror-like ... | The video presents an abstract composition centered around a hexagonal shape adorned ... |
### 65×512×512 Text-to-Video Generation
| 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) |
| --- | --- | --- | --- |
| | | | |
| Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. | 3D animation of a small, round, fluffy creature with big, expressive eyes explores a ... | A corgi vlogging itself in tropical Maui. | In a studio, there is a painting depicting a ship sailing through the rough sea. |
| | | | |
| A robot dog trots down a deserted alley at night, its metallic paws clinking softly ... | A solitary spider weaves its web in a quiet corner. The web shimmers and glows with ... | A lone surfer rides a massive wave, skillfully maneuvering through the surf. The water ... | A solitary cheetah sprints across the savannah, its powerful muscles propelling it ... |
| | | | |
| A solitary astronaut plants a flag on an alien planet covered in crystal formations ... | At dawn's first light, a spaceship slowly exits the edge of the galaxy against a ...| A dapper puppy in a miniature suit, basking in the afternoon sun, adjusting his tie ... | A wise old elephant painting abstract art with its trunk, each stroke a burst of color ... |
| | | | |
| In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two ... | A Shiba Inu dog wearing a beret and black turtleneck. | A painting of a boat on water comes to life, with waves crashing and the boat becoming ... | Many spotted jellyfish pulsating under water. Their bodies are transparent and glowing ... |
| | | | |
| An animated hedgehog with distinctive spiky hair and large eyes is seen exploring a ... | An animated rabbit in a playful pink snowboarding outfit is carving its way down a ... | A person clad in a space suit with a helmet and equipped with a chest light and arm ... | |
### 65×512×512 Video Editing
| generated 65×512×512 (2.7s) | edited 65×512×512 (2.7s) |
| --- | --- |
| | |
| | |
| | |
### 512×512 Text-to-Image Generation
## Detailed Technical Report
### CausalVideoVAE
#### Model Structure
As the number of frames increases, the encoder overhead of CausalVideoVAE gradually rises. When training with 257 frames, 80GB of VRAM is insufficient for the VAE to encode the video. Therefore, we reduced the number of CausalConv3D layers, retaining only the last two stages of CausalConv3D in the encoder. This change significantly lowers the overhead while maintaining nearly the same performance. Note that we only modified the encoder; the decoder still retains all CausalConv3D layers, as training the Diffusion Model does not require the decoder.
We compare the computational overhead of the two versions by testing the forward inference of the encoder on the H100.
| Version | 129×256×256 | | 257×256×256 | | 513×256×256 | |
|---|---|---|---|---|---|---|
| | Peak Mem. | Speed | Peak Mem. | Speed |Peak Mem. | Speed |
| v1.0.0 | 22G | 2.9 it/s | OOM | - | OOM | - |
| v1.1.0 | 18G | 4.9 it/s | 34G | 2.5 it/s | 61G | 1.2 it/s |
#### Temporal Module
In v1.0.0, our temporal module had only TemporalAvgPool. TemporalAvgPool leads to the loss of high-frequency information in the video, such as details and edges. To address this issue, we improved this module in v1.1.0. As shown in the figure below, we introduced convolution and added learnable weights, allowing different branches to decouple different features. When we omit CausalConv3D, the reconstructed video is very blurry; conversely, when we omit TemporalAvgPool, it becomes very sharp.
| | SSIM↑ | LPIPS↓ | PSNR↑ |
|---|---|---|---|
| Base | 0.850 | 0.091 | 28.047 |
| + Frames | 0.868 | 0.070 | 28.829 |
| + Reset mixed factor | 0.873 | 0.070 | 29.140 |
#### Training Details
Similar to v1.0.0, we initialized from the Latent Diffusion's VAE and used tail initialization. For CausalVideoVAE, we trained for 100k steps in the first stage with a video shape of 9×256×256. Subsequently, we increased the frame count from 9 to 25 and found that this significantly improved the model's performance. It is important to clarify that we enabled the mixed factor during both the first and second stages, with the mixing weight a = sigmoid(mixed factor) reaching 0.88 at the end of training, indicating the model's tendency to retain low-frequency information. In the third stage, we reinitialized the mixed factor to 0.5 (sigmoid(0.5) = 0.6225), which further enhanced the model's capabilities.
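The role of the mixed factor can be illustrated with a small sketch. It assumes the sigmoid of the mixed factor weights the low-frequency TemporalAvgPool branch and its complement weights the convolution branch; the function names and the list-based "features" are hypothetical and only meant to show the arithmetic, not the repository's actual module.

```python
import math


def sigmoid(x):
    """Logistic function used to squash the learnable mixed factor into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    """Inverse of sigmoid: the raw mixed factor that yields weight p."""
    return math.log(p / (1.0 - p))


def mix_branches(avg_pool_out, conv_out, mixed_factor):
    """Blend two branch outputs element-wise.

    Assumption: the sigmoid weight goes to the avg-pool (low-frequency) branch
    and the remainder to the convolution (high-frequency) branch.
    """
    a = sigmoid(mixed_factor)
    return [a * p + (1.0 - a) * c for p, c in zip(avg_pool_out, conv_out)]


print(round(sigmoid(0.5), 4))   # weight after re-initialising the factor to 0.5
print(round(logit(0.88), 2))    # raw factor corresponding to a weight of 0.88
print(mix_branches([1.0, 1.0], [0.0, 0.0], 0.0))  # equal blend when the factor is 0
```

Under this reading, a weight of 0.88 means the output is dominated by the pooled branch, and resetting the factor pulls the blend back toward the convolution branch.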
#### Loss Function
We found that using GAN loss helps retain high-frequency information and alleviates grid artifacts. Additionally, we observed that switching from 2D GAN to 3D GAN provides further improvements.
| GAN Loss/Step | SSIM↑ | LPIPS↓ | PSNR↑ |
|---|---|---|---|
| 2D/80k | 0.879 | 0.068 | 29.480 |
| 3D/80k | 0.882 | 0.067 | 29.890 |
#### Inference Tricks
We introduced a method called **temporal rollback tiled convolution**, a tiling approach specifically designed for CausalVideoVAE. All windows except the first one discard their first frame, because the first frame in a window is treated as an image, while the remaining frames should be treated as video frames.
We tested the speed on the H100 with a window size of 65×256×256.
| Version | 129×256×256 | | 257×256×256 | | 513×256×256 | |
|---|---|---|---|---|---|---|
| | Peak Mem. | Speed | Peak Mem. | Speed |Peak Mem. | Speed |
| 4×8×8 | 10G | 1.3 s/it | 10G | 2.6 s/it | 10G | 5.3 s/it |
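The windowing behind temporal rollback tiled convolution can be sketched as follows. This is a minimal illustration that assumes each window after the first starts one frame early (the "rollback") and then drops that extra leading frame; `rollback_windows` and `kept_frames` are hypothetical helper names, and the real implementation operates on tensors rather than index ranges.

```python
def rollback_windows(num_frames, window=65):
    """Return (start, end, drop_first) triples that tile num_frames frames.

    Every window after the first overlaps its predecessor by one frame and is
    marked so that its first output frame is discarded.
    """
    assert window >= 2, "a window needs at least two frames to make progress"
    windows = []
    start = 0
    while True:
        end = min(start + window, num_frames)
        windows.append((start, end, start > 0))
        if end >= num_frames:
            break
        start = end - 1  # roll back one frame so the next window re-reads it
    return windows


def kept_frames(windows):
    """Count frames that survive after dropping the rolled-back leading frames."""
    return sum((end - start) - (1 if drop else 0) for start, end, drop in windows)


for total in (129, 257, 513):
    tiles = rollback_windows(total, window=65)
    print(total, len(tiles), kept_frames(tiles))
```

Because every window has the same bounded size, peak memory stays roughly constant as the video grows, which is consistent with the flat 10G column in the table above.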
### Data Construction
Since Open-Sora-Plan supports joint training of images and videos, our data collection is divided into two parts: images and videos. Images do not need to originate from videos; they are independent datasets. We spent approximately 32×240 H100 hours generating image and video captions, and all of this is **open source**!
#### Image-Text Collection Pipeline
We obtained 11 million image-text pairs from [Pixart-Alpha](https://huggingface.co/datasets/PixArt-alpha/SAM-LLaVA-Captions10M), with captions generated by [LLaVA](https://github.com/haotian-liu/LLaVA). Additionally, we utilized the high-quality OCR dataset [Anytext-3M](https://github.com/tyxsspa/AnyText), which pairs each image with corresponding OCR characters. However, these captions were insufficient to describe the entire image, so we used [InternVL-1.5](https://github.com/OpenGVLab/InternVL) for supplementary descriptions. Since T5 only supports English, we filtered for English data, which constitutes about half of the complete dataset. Furthermore, we selected high-quality images from [Laion-5B](https://laion.ai/blog/laion-5b/) to enhance human-like generation quality. The selection criteria included high resolution, high aesthetic scores, and watermark-free images containing people.
Here, we are open-sourcing the prompt used for InternVL-1.5:
```
# for anytext-3m
Combine this rough caption: "{}", analyze the image in a comprehensive and detailed manner. "{}" can be recognized in the image.
# for human-160k
Analyze the image in a comprehensive and detailed manner.
```
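As a small illustration of how the Anytext-3M template above could be filled in, the sketch below assumes the first placeholder receives the rough caption and the second receives the recognized OCR text. The helper name and the example strings are hypothetical.

```python
ANYTEXT_PROMPT = (
    'Combine this rough caption: "{}", analyze the image in a comprehensive '
    'and detailed manner. "{}" can be recognized in the image.'
)
HUMAN_PROMPT = "Analyze the image in a comprehensive and detailed manner."


def build_anytext_prompt(rough_caption, ocr_text):
    """Fill the two placeholders in order: rough caption first, OCR text second."""
    return ANYTEXT_PROMPT.format(rough_caption, ocr_text)


# Hypothetical inputs, for illustration only.
print(build_anytext_prompt("a street sign", "STOP"))
```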
| Name | Image Source | Text Captioner | Num pair |
|---|---|---|---|
| SAM-11M | [SAM](https://ai.meta.com/datasets/segment-anything/) | [LLaVA](https://github.com/haotian-liu/LLaVA) | 11,185,255 |
| Anytext-3M-en | [Anytext](https://github.com/tyxsspa/AnyText) | [InternVL-1.5](https://github.com/OpenGVLab/InternVL) | 1,886,137 |
| Human-160k | [Laion](https://laion.ai/blog/laion-5b/) | [InternVL-1.5](https://github.com/OpenGVLab/InternVL) | 162,094 |
#### Video-Text Collection Pipeline
In v1.0.0, we sampled one frame from each video to generate captions. However, as video length increased, a single frame could not adequately describe the entire video's content or temporal movements. Therefore, we used a video captioner to generate captions for the entire video clip. Specifically, we used [ShareGPT4Video](https://sharegpt4video.github.io/), which effectively covers temporal information and describes the entire video content. The v1.1.0 video dataset comprises approximately 3k hours, compared to only 300 hours in v1.0.0. As before, we have open-sourced all text annotations and videos (both under the CC0 license), which can be found [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main).
| Name | Hours | Num frames | Num pair |
|---|---|---|---|
| [Mixkit](https://mixkit.co/) | 42.0h | 65 | 54,735 |
| | | 513 | 1,997 |
| [Pixabay](https://pixabay.com/) | 353.3h | 65 | 601,513 |
| | | 513 | 51,483 |
| [Pexel](https://www.pexels.com/) | 2561.9h | 65 | 3,832,666 |
| | | 513 | 271,782 |
### Training Diffusion Model
Similar to our previous work, we employed a multi-stage cascaded training method. Below is our training card:
#### Stage 1
We initially believed that the performance of the diffusion model would improve with longer training. Surprisingly, by observing the [logs](https://api.wandb.ai/links/linbin/o76j03j4), we found that videos generated at 50k steps were of higher quality than those at 70-100k steps. In fact, extensive sampling revealed that checkpoints at 40-60k steps outperformed those at 80-100k steps. Quantitatively, 50k steps correspond to approximately 2 epochs of training. It is currently unclear whether this is due to overfitting from a small dataset or the limited capacity of the 2+1D model.
#### Stage 2
In the second stage, we used Huawei Ascend computing power for training. This stage's training and inference were fully supported by Huawei. We conducted sequence parallel training and inference on a large-scale cluster, distributing one sample across eight ranks. Models trained on Huawei Ascend can also be loaded into GPUs and generate videos of the same quality.
#### Stage 3
In the third stage, we further increased the frame count to 513 frames, approximately 21 seconds at 24 FPS. However, this stage presents several challenges, such as ensuring temporal consistency in the 2+1D model over long durations and whether the current amount of data is sufficient. We are still training the model for this stage and continuously monitoring its progress.
| Name | Stage 1 | Stage 2 | Stage 3 |
|---|---|---|---|
| Training Video Size | 65×512×512 | 221×512×512 | 513×512×512 |
| Compute (#Num x #Hours) | 80 H100 × 72 | 512 Ascend × 72 | Under Training |
| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512) | Under Training |
| Log | [wandb](https://api.wandb.ai/links/linbin/o76j03j4) | - | - |
| Training Data | ~3k hours videos + 13M images | | |
### Video Editing
The recently proposed [ReVideo](https://mc-e.github.io/project/ReVideo/) achieves accurate video editing by modifying the first frame and applying motion control within the edited area. Although it achieves excellent video editing performance, the editing length is limited by the base model [SVD](https://github.com/Stability-AI/generative-models). Open-Sora, as a fundamental model for long-video generation, can compensate for this issue. Currently, we are collaborating with the ReVideo team to use Open-Sora as the base model for long video editing. Some preliminary results are shown [here]().
The initial version still needs improvement in several aspects. In the future, we will continue to explore integration with ReVideo to develop improved long-video editing models.
## Failure Cases and Discussion
Despite the promising results of v1.1.0, there remains a gap between our model and Sora. Here, we present some failure cases and discuss them.
### CausalVideoVAE
Despite the significant performance improvement of VAE in v1.1.0 over v1.0.0, we still encounter failures in challenging cases, such as sand dunes and leaves. The video on the left shows the reconstructed video downsampled by a factor of 4 in time, while the video on the right is downsampled by a factor of 2. Both exhibit jitter when reconstructing fine-grained features. This indicates that reducing temporal downsampling alone cannot fully resolve the jitter issue.
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1a87d6d8-4bf1-4b4e-83bb-84870c5c3a11
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1a87d6d8-4bf1-4b4e-83bb-84870c5c3a11
### Diffusion Model
#### Semantic distortion
On the left is a video generated by v1.1.0 showing a puppy in the snow. In this video, the puppy's head exhibits semantic distortion, indicating that the model struggles to correctly identify which head belongs to which dog. On the right is a video generated by Sora's [base model](https://openai.com/index/video-generation-models-as-world-simulators/). We observe that Sora's early base model also experienced semantic distortion issues. This suggests that we may achieve better results by scaling up the model and increasing the amount of training data.
Prompt: A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
| Ours | Sora Base×1 | Sora Base×4 | Sora Base×32 |
|---|---|---|---|
| | | | |
#### Limited dynamics
The primary difference between videos and images lies in their dynamic nature, where objects undergo a series of changes across consecutive frames. However, the videos generated by v1.1.0 still contain many instances of limited dynamics. Upon reviewing a large number of training videos, we found that while web-crawled videos have high visual quality, they are often filled with meaningless close-up shots. These close-ups typically show minimal movement or are even static. On the left, we present a generated video of a bird, while on the right is a training video we found, which is almost static. There are many similar videos in the dataset from stock footage sites.
Prompt: This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird's striking appearance.
| Ours | Raw video |
|---|---|
| | |
#### Negative prompt
We found that using negative prompts can significantly improve video quality, even though we did not explicitly tag the training data with different labels. On the left is a video sampled using a negative prompt, while on the right is a video generated without a negative prompt. This suggests that we may need to incorporate more prior knowledge into the training data. For example, when a video has a watermark, we should note "watermark" in the corresponding caption. When a video's bitrate is too low, we should add more tags to distinguish it from high-quality videos, such as "low quality" or "blurry." We believe that explicitly injecting these priors can help the model differentiate between the vast amounts of pretraining data (low quality) and the smaller amounts of fine-tuning data (high quality), thereby generating higher quality videos.
Prompt: A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
Negative Prompt: distorted, discontinuous, ugly, blurry, low resolution, motionless, static, low quality
| With Negative Prompt | Without Negative Prompt |
|---|---|
| | |
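The report does not state how the negative prompt is consumed at sampling time. A common convention in diffusion samplers is to let the negative-prompt prediction take the place of the unconditional prediction in classifier-free guidance; the sketch below illustrates that convention only, with hypothetical names and plain lists standing in for noise tensors.

```python
def guided_noise(eps_cond, eps_neg, guidance_scale):
    """Classifier-free guidance with a negative prompt.

    eps_cond: noise predicted with the positive prompt.
    eps_neg:  noise predicted with the negative prompt (assumed to replace the
              usual empty-prompt prediction).
    The result moves away from eps_neg and toward eps_cond.
    """
    return [n + guidance_scale * (c - n) for c, n in zip(eps_cond, eps_neg)]


# With a scale of 1.0 the negative prompt has no effect.
print(guided_noise([2.0, 4.0], [1.0, 2.0], 1.0))
# Larger scales push the prediction further from the negative direction.
print(guided_noise([2.0, 4.0], [1.0, 2.0], 3.0))
```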
## Future Work
In our future work, we will focus on two main areas: (1) data scaling and (2) model design. Once we have a robust baseline model, we will extend it to handle variable durations and conditional control models.
### Data Scaling
#### Data source
As mentioned earlier, our dataset is entirely sourced from stock footage websites. Although these videos are of high quality, many consist of close-up shots of specific areas, resulting in slow motion in the videos. We believe this is one of the main reasons for the limited dynamics observed. Therefore, we will continue to collect datasets from diverse sources to address this issue.
#### Data volume
In v1.1.0, our dataset comprises only ~3k hours of video. We are actively collecting more data and anticipate that the video dataset for the next version will reach ~100k hours. We welcome recommendations from the open-source community for additional datasets.
### Model Design
#### CausalVideoVAE
In our internal testing, even without downsampling in time, we found that it is not possible to completely resolve the jitter issue in reconstructing fine-grained features. Therefore, we need to reconsider how to mitigate video jitter to the greatest extent possible while simultaneously supporting both images and videos. We will introduce a more powerful CausalVideoVAE in the next version.
#### Diffusion Model
In v1.1.0, we found that 2+1D models can generate higher-quality videos in short durations. However, for long videos, they tend to exhibit discontinuities and inconsistencies. Therefore, we will explore more possibilities in model architecture to address this issue.
================================================
FILE: docs/Report-v1.2.0.md
================================================
# Report v1.2.0
In May 2024, we launched Open-Sora-Plan v1.1.0, featuring a 2+1D model architecture that could be quickly utilized for exploratory training in text-to-video generation tasks. However, when handling dense visual tokens, the 2+1D architecture could not simultaneously process spatial and temporal dimensions. Therefore, we transitioned to **a 3D full attention architecture**, which better captures the joint spatial-temporal features. Although this version is experimental, it advances video generation architecture to a new realm, leading us to release it as v1.2.0.
Compared to previous video generation models, Open-Sora-Plan v1.2.0 offers the following improvements:
1. **Better compressed visual representations**. We optimized the structure of CausalVideoVAE, which now delivers enhanced performance and higher inference efficiency.
2. **Better video generation architecture**. Instead of 2+1D, we use a diffusion model with a 3D full attention architecture, which provides a better understanding of the world.
### Open-Source Release
We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model are made publicly available.
- Code: All training scripts and sample scripts.
- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0).
- Data: Filtered data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0).
## Gallery
93×1280×720 Text-to-Video Generation. The video quality has been compressed for playback on GitHub.
## Detailed Technical Report
### CausalVideoVAE
#### Model Structure
The VAE in version 1.2.0 maintains the overall architecture of the previous version but merges the temporal and spatial downsampling layers. In version 1.1.0, we performed spatial downsampling (stride=1,2,2) followed by temporal downsampling (stride=2,1,1). In version 1.2.0, we conduct both spatial and temporal downsampling simultaneously (stride=2,2,2) and perform spatial-temporal upsampling in the decoder (interpolate_factor=2,2,2).
Due to the absence of additional convolutions during downsampling and upsampling, this method more seamlessly inherits the weights from the SD2.1 VAE, leading to improved initialization of our VAE.
#### Training Details
As with v1.1.0, we initialize from the [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse) using tail initialization. We perform the first phase of training on the Kinetics-400 video dataset, then use the EMA weights from this phase to initialize the second phase, which is fine-tuned on high-quality data (collected in v1.1.0). All training is conducted on 25-frame 256×256 videos using **one A100 node**.
| Training stage | Dataset | Training steps |
|---|---|---|
| 1 | K400 | 200,000 |
| 2 | collected in v1.1.0 | 450,000 |
#### Evaluation
We evaluated our VAE on the validation sets of two video datasets: [Webvid](https://github.com/m-bain/webvid) and [Panda70m](https://github.com/snap-research/Panda-70M/), and compared it with our [v1.1.0](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.1.0.md), [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse), [CV-VAE](https://github.com/AILab-CVC/CV-VAE), and [Open-Sora's VAE](https://github.com/hpcaitech/Open-Sora). The Webvid validation set contains 5k videos, while the Panda70m validation set has 6k videos. The videos were resized to 256 pixels on the short side, center-cropped to 256x256, and then 33 consecutive frames were extracted. We used PSNR, SSIM, and LPIPS metrics, and measured the encoding speed on an A100 GPU. The specific results are as follows:
**WebVid**
| Model | Compress Ratio | PSNR↑ | SSIM↑ | LPIPS↓ |
|---|---|---|---|---|
| SD2-1 VAE | 1x8x8 | 30.19 | 0.8379 | 0.0568 |
| SVD VAE | 1x8x8 |31.15 |0.8686 | **0.0547** |
| CV-VAE | 4x8x8 | 30.76 | 0.8566 | 0.0803 |
| Open-Sora VAE | 4x8x8 | 31.12 | 0.8569 | 0.1003 |
| Open-Sora Plan v1.1 | 4x8x8 | 30.26 | 0.8597 |0.0551 |
| Open-Sora Plan v1.2 | 4x8x8| **31.16** | **0.8694** | 0.0586 |
**Panda70M**
| Model | Compress Ratio | PSNR↑ | SSIM↑ | LPIPS↓ |
|---|---|---|---|---|
| SD2-1 VAE | 1x8x8 |30.40 | 0.8894 | 0.0396 |
| SVD VAE | 1x8x8 |31.00 | **0.9058** | **0.0379** |
| CV-VAE | 4x8x8| 29.57 | 0.8795 | 0.0673 |
| Open-Sora VAE | 4x8x8 | **31.06** | 0.8969 | 0.0666 |
| Open-Sora Plan v1.1 | 4x8x8 | 29.16 | 0.8844 | 0.0481 |
| Open-Sora Plan v1.2 | 4x8x8| 30.49 |0.8970 |0.0454|
**Encode Time on A100**
|Input Size| CV-VAE | Open-Sora | Open-Sora Plan v1.1 | Open-Sora Plan v1.2 |
|---|---|---|---|---|
| 33x256x256 | 0.186 | 0.147 |0.104 | **0.102** |
| 81x256x256 | 0.465 | 0.357 |0.243 | **0.242** |
### Training Text-to-Video Diffusion Model
#### Model Structure
The most significant change is that we **replaced all 2+1D Transformer blocks with 3D full attention blocks**. Each video is first processed by a patch embedding layer, which downsamples the spatial dimensions by a factor of 2. The video is then flattened into a one-dimensional sequence across the frame, width, and height dimensions. We replaced [T5-XXL](https://huggingface.co/DeepFloyd/t5-v1_1-xxl) with [mT5-XXL](https://huggingface.co/google/mt5-xxl) to enhance multilingual adaptation. Additionally, we incorporated RoPE.
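To give a feel for the sequence lengths that 3D full attention has to handle, the sketch below counts tokens for a few training shapes. It assumes the 4×8×8 causal VAE maps T frames to (T-1)/4+1 latent frames (first frame encoded as an image), and that the patch embedding halves only the spatial dimensions. The function names are hypothetical.

```python
def latent_shape(frames, height, width, t_stride=4, s_stride=8):
    """Latent (frames, height, width) for a causal VAE with 4x8x8 compression.

    Assumption: the first frame is encoded on its own, so frame counts follow
    the 4k+1 pattern and map to k+1 latent frames.
    """
    assert (frames - 1) % t_stride == 0, "frame count should be 4k+1"
    return ((frames - 1) // t_stride + 1, height // s_stride, width // s_stride)


def num_tokens(frames, height, width, patch=2):
    """Sequence length after 2x spatial patch embedding and flattening."""
    t, h, w = latent_shape(frames, height, width)
    return t * (h // patch) * (w // patch)


for shape in ((29, 480, 640), (29, 720, 1280), (93, 720, 1280)):
    print(shape, latent_shape(*shape), num_tokens(*shape))
```

Since attention cost grows quadratically with sequence length, moving from 29 to 93 frames at 720p raises the token count threefold and the attention cost roughly ninefold under these assumptions.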
#### Sequence Parallelism
Due to the high computational complexity of 3D full attention, we must allocate a video across 2 GPUs for parallel processing when training with long-duration and high-resolution videos. We can control the number of GPUs used for a video sample by adjusting the batch size on a node. For example, with `sp_size=8` and `train_sp_batch_size=4`, 2 GPUs are used for a single sample. **We support sequence parallelism for both training and inference**.
**Training on 93×720p**, we report speed on H100.
| GPU (sp_size) | batch size | Enable sp | Train_sp_batch_size | Speed | Step per day |
|---|---|---|---|---|---|
|8|8|×|-|100s/step|~850|
|8|-|√|4|53s/step|~1600|
|8|-|√|2|27s/step|~3200|
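The relationship between `sp_size`, `train_sp_batch_size`, and the throughput figures in the table can be checked with a few lines of arithmetic. The helper names are hypothetical; the numbers are taken from the table above.

```python
SECONDS_PER_DAY = 24 * 60 * 60


def gpus_per_sample(sp_size, train_sp_batch_size):
    """Number of GPUs that share one video sample under sequence parallelism."""
    assert sp_size % train_sp_batch_size == 0
    return sp_size // train_sp_batch_size


def steps_per_day(seconds_per_step):
    """Optimizer steps completed in 24 hours at a given step time."""
    return SECONDS_PER_DAY / seconds_per_step


print(gpus_per_sample(8, 4))  # GPUs per sample for the 53 s/step row
print(gpus_per_sample(8, 2))  # GPUs per sample for the 27 s/step row
for step_time in (100, 53, 27):
    print(step_time, round(steps_per_day(step_time)))
```

The computed values (864, 1630, and 3200 steps per day) agree with the rounded "Step per day" column.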
**Inference on 93×720p**, we report speed on H100.
| Size | 1 GPU | 8 GPUs |
|---|---|---|
|29×720p|420s/100step|80s/100step|
|93×720p|3400s/100step|450s/100step|
#### Dynamic training
Deep neural networks are typically trained using batched inputs. For efficient hardware processing, batch shapes are fixed, leading to a fixed data size. This requires either cropping or padding images to a uniform size, both of which have drawbacks: cropping discards information and degrades performance, while padding is computationally inefficient. Generally, there are three methods for training with arbitrary token counts: Patch n' Pack, bucket, and pad-mask.
**Patch n' Pack** ([NaViT](https://arxiv.org/abs/2307.06304)): This method bypasses the fixed sequence length limitation by combining tokens from multiple samples into a new sample. It allows variable-resolution images while maintaining aspect ratios, thereby reducing training time and enhancing performance and flexibility. However, it involves significant code modifications and requires re-adaptation when exploring different model architectures in fields with unstable model designs.
**Bucket** ([Pixart-alpha](https://arxiv.org/abs/2310.00426), [Open-Sora](https://github.com/hpcaitech/Open-Sora)): This method packages data of different resolutions into buckets, sampling batches from each bucket to ensure the same resolution within each batch. It requires minimal code modifications to the model, mainly adjusting the data sampling strategy.
**Pad-mask** ([FiT](https://arxiv.org/abs/2402.12376), our v1.0/v1.1): This method sets a maximum resolution and pads all data to this resolution, generating a corresponding mask. Although the approach is straightforward, it is computationally inefficient.
We believe that current video generation models are still in an exploratory phase. Extensive modifications to model code during this period can incur unnecessary development costs. The pad-mask method, while straightforward, is computationally inefficient and can waste resources in video, which involves dense computations. Ultimately, we chose the bucket strategy, which requires no modifications to the model code. Next, we will explain how our bucket strategy supports arbitrary lengths and resolutions. For simplicity, we will use video duration as an example:
We define a megabatch as the total data processed in a single step across all GPUs. A megabatch can be divided into multiple batches, with each batch corresponding to the data processed by a single GPU.
**Sort by frame**: The first step is to count the number of frames in all video data and sort them. This step aims to group similar data together, with sorting being one method to achieve this.
**Group megabatch**: Next, all data is divided into groups, each forming a megabatch. Since all data is pre-sorted, most videos within a megabatch have the same number of frames. However, there will always be boundary cases, such as having both 61-frame and 1-frame videos in a single megabatch.
**Re-organize megabatch**: We re-organize these special megabatches, which actually constitute a small proportion. We randomly replace the minority data in the megabatch with the majority data, thus re-organizing it into a megabatch with a uniform frame count.
**Shuffle megabatch**: To ensure data randomness, we shuffle both within each megabatch and between different megabatches.
When supporting dynamic resolutions, we simply replace each sample's frame sequence with (frame × height × width). This method ensures that the data dimension processed by each GPU in every step is the same, preventing situations where GPU1 waits for GPU0 to finish processing a longer video. Moreover, it is entirely decoupled from the model code, serving as a plug-and-play video sampling strategy.
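The four steps above (sort, group, re-organize, shuffle) can be sketched as a standalone sampler. This is a simplified illustration rather than the repository's sampler: `build_megabatches` is a hypothetical name, samples are represented only by their frame counts, and an incomplete trailing megabatch is dropped for brevity.

```python
import random
from collections import Counter


def build_megabatches(frame_counts, megabatch_size, seed=0):
    """Return lists of sample indices; each list is one uniform megabatch."""
    rng = random.Random(seed)

    # 1. Sort by frame: order sample indices by their frame count.
    order = sorted(range(len(frame_counts)), key=lambda i: frame_counts[i])

    # 2. Group megabatch: cut the sorted order into consecutive groups.
    last_start = len(order) - megabatch_size + 1
    groups = [order[i:i + megabatch_size] for i in range(0, last_start, megabatch_size)]

    megabatches = []
    for group in groups:
        # 3. Re-organize megabatch: replace minority samples with random
        #    majority samples so every entry has the same frame count.
        counts = Counter(frame_counts[i] for i in group)
        majority = counts.most_common(1)[0][0]
        majority_ids = [i for i in group if frame_counts[i] == majority]
        fixed = [i if frame_counts[i] == majority else rng.choice(majority_ids)
                 for i in group]
        # 4a. Shuffle within the megabatch.
        rng.shuffle(fixed)
        megabatches.append(fixed)

    # 4b. Shuffle between megabatches.
    rng.shuffle(megabatches)
    return megabatches


# Toy dataset: 3 one-frame images, 6 clips of 29 frames, 7 clips of 61 frames.
frames = [1] * 3 + [29] * 6 + [61] * 7
for mb in build_megabatches(frames, megabatch_size=4):
    print([frames[i] for i in mb])
```

For dynamic resolution, the same sketch applies if the sort key and the majority test use a (frames, height, width) tuple instead of the frame count alone. Each megabatch can then be split evenly into per-GPU batches.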
#### Training stage
Similar to previous work, we use a multi-stage training approach. With the 3D DiT architecture, all parameters can be transferred from images to videos without loss. To explore training costs, all parameters of the diffusion model are trained from scratch. Therefore, we first train a text-to-image model, using the training strategy from [Pixart-alpha](https://arxiv.org/abs/2310.00426).
The video model is initialized with weights from a 480p image model. We first train on 480p videos with 29 frames. Next, we adapt the weights to 720p resolution, training on approximately 6 million higher-quality (HQ) samples from Panda70M, filtered for aesthetic quality and motion. We then refine the model with an even higher-quality subset of 1 million samples. Finally, we use filtered data (collected in v1.1.0) to fine-tune on 93-frame 720p videos. Below is our training card. We release the annotation file [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/anno_json).
| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |Stage 5 |
|---|---|---|---|---|---|
| Training Video Size | 1×320×240 | 1×640×480 | 29×640×480 | 29×1280×720 | 93×1280×720 |
| Training Step| 146k | 200k | 30k | 21k | 3k |
| Compute (#Num x #Hours) | 32 Ascend × 81 | 32 Ascend × 142 | 128 Ascend × 38 | 256 H100 × 64 | 256 H100 × 84 |
| Checkpoint | - | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p) |
| Log | - | - | [wandb](https://api.wandb.ai/links/1471742727-Huawei/trdu2kba) | [wandb](https://api.wandb.ai/links/linbin/vvxvcd7s) | [wandb](https://api.wandb.ai/links/linbin/easg3qkl) |
| Training Data | [10M SAM](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/sam_image_11185255_resolution.json) | 5M internal image data | [6M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ6M.json) | [6M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ6M.json) | [1M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ1M.json) and [100k HQ data](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/anno_json) (collected in v1.1.0) |
Additionally, we fine-tuned 3.5k steps from the final 93×720p to get [93×480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p) for community research use.
### Training Image-to-Video Diffusion Model
#### Model Structure
To reuse the weights of the Text-to-Video model, our Image-to-Video model is inspired by the Stable Diffusion Inpainting Model and adopts a strategy based on frame-level inpainting. By incorporating three types of information—original noise, masked video, and mask—under different control frame conditions, our model can generate coherent videos while ensuring flexibility in its usage.
Compared to the denoiser structure of the Text-to-Video model, the Inpainting Model's denoiser changes only the number of channels in the `conv in` layer. To ensure the model starts from a good prior, we introduce the masked video and mask information through zero initialization. We believe this is due to the 2+1D structure's lack of ability to establish long-range information dependencies, and relying solely on attention in the temporal dimension makes it difficult to capture information changes under frame control. In Text-to-Video tasks, this phenomenon is not as evident because all frames share the same text prompt embedding. However, in Image-to-Video tasks, simply concatenating images in the channel dimension does not ensure the model can accurately capture changes between frames. This is because the model cannot directly replicate image information from the channels to reduce the loss, and the 2+1D structure's interaction solely on the temporal axis fails to allow the model to discern which information from the control frames can be utilized, especially when there are significant differences between frames. Therefore, without shared image-semantic information, the control frame information might not be effectively conveyed to each frame.
##### About Semantic Adapter
In previous models based on the Unet 2+1D architecture, it is necessary to input the control frames into the CLIP model to obtain semantic embeddings. These semantic embeddings are then injected into the denoiser through cross-attention. The structure that extracts CLIP embeddings and injects them into the denoiser is commonly referred to as a semantic adapter.
In the 2+1D architecture, the semantic adapter is commonly present. Additionally, papers like [DynamiCrafter](https://arxiv.org/abs/2310.12190) have pointed out that incorporating the semantic adapter helps maintain stability in the generated videos. We believe this is because the 2+1D structure lacks the ability to establish long-range information dependencies, and relying solely on attention in the temporal dimension makes it difficult to capture information changes under frame control. In the Text-to-Video task, this phenomenon is not as evident because all frames share the same text prompt embedding. However, in the Image-to-Video task, without shared semantic information, it may lead to the inability to effectively transfer control frame information to each individual frame.
We conducted a simple comparison of the performance of using the Inpainting Model under the 2+1D structure (Open-Sora Plan v1.1, left in the figure) versus the 3D structure (Open-Sora Plan v1.2, right in the figure). With the same number of optimization steps, the probability of unstable visual performance in the 2+1D structure was significantly higher than in the 3D structure. Even at convergence, the 2+1D structure's visual stability was still inferior to that of the 3D structure, and it was even worse than the early training stages of the 3D structure.
## Future Work and Discussion
#### CausalVideoVAE
We observed that high-frequency motion information in videos tends to exhibit jitter, and increasing training duration and data volume does not significantly alleviate this issue. In videos, compressing the duration while maintaining the original latent dimension can lead to significant information loss. A more robust VAE will be released in the next version.
#### Diffusion Model
We replaced T5 with mT5 to enhance multilingual capabilities, but this capability is limited as our training data is currently only in English. The multilingual ability primarily comes from the mT5 mapping space. We will explore additional text encoders and expand the data in the next steps.
Our model performs well in generating character consistency, likely due to panda70m being a character-centric dataset. However, it still shows poor performance in text consistency and object generalization. We suspect this may be due to the limited amount of data the model has seen, as evidenced by the non-convergence of the loss in the final stage. **We hope to collaborate with the open-source community to optimize the 3D DiT architecture.**
================================================
FILE: docs/Report-v1.3.0.md
================================================
# Report v1.3.0
In August 2024, we released Open-Sora-Plan v1.2.0, transitioning to a 3D full attention architecture, which enhanced the capture of joint spatial-temporal features. However, the substantial computational cost made it unsustainable, and the lack of a clear training strategy hindered continuous progress along a focused path.
In version 1.3.0, Open-Sora-Plan introduced the following five key features:
**1. A more powerful and cost-efficient WFVAE.** We decompose video into several sub-bands using wavelet transforms, naturally capturing information across different frequency domains, leading to more efficient and robust VAE learning.
**2. Prompt Refiner.** A large language model designed to refine short text inputs.
**3. High-quality data cleaning strategy.** The cleaned panda70m dataset retains only 27% of the original data.
**4. DiT with new sparse attention.** A more cost-effective and efficient learning approach.
**5. Dynamic resolution and dynamic duration.** This enables more efficient utilization of videos with varying lengths (treating a single frame as an image).
### Open-Source Release
We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model will be made publicly available.
- Code: All training scripts and sample scripts.
- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0).
- Data: The data of prompt refiner is [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner).
## Gallery
Text & Image to Video Generation.
[Demo video (Bilibili)](https://www.bilibili.com/video/BV1KR2fYPEF5/?spm_id_from=333.999.0.0&vd_source=cfda99203e659100629b465161f1d87d)
## Detailed Technical Report
### WF-VAE
As video generation models move toward higher resolutions and longer durations, the computational cost of video VAEs grows exponentially, becoming unsustainable. Most related work addresses this by using tiling to reduce inference memory consumption. However, in high-resolution, long-duration scenarios, tiling significantly increases inference time. Additionally, since tiling is lossy for latents, it can lead to visual artifacts such as shadows or flickering in the generated videos. We therefore introduce WFVAE, a new model designed to address these problems.
#### Model Structure
The compression rate fundamentally determines the quality of VAE-reconstructed videos. We analyzed the energy and entropy of different subbands obtained through wavelet transform and found that most of the energy in videos is concentrated in the low-frequency bands. Moreover, by replacing the `LLL` subband of the VAE-reconstructed video with the original video's `LLL` subband, we observed a significant improvement in the spatiotemporal quality of the videos.
In previous VAE architectures, the lack of a "highway" for transmitting the dominant energy during video compression meant that this pathway had to be gradually established during model training, leading to redundancy in model parameters and structure. Therefore, in our model design, we created a more efficient transmission path for the LLL subband energy, significantly simplifying the model architecture, reducing inference time, and lowering memory consumption.
#### Training Details
More details will be provided in the forthcoming paper.
#### Ablation Study
In our experiments, we used the K400 training and validation sets, conducted on 8xH100 GPUs. The latent dimension was fixed at 4. We observed that as model parameters increased, there was still room for improvement in reconstruction metrics. GroupNorm showed instability during training, performing worse than LayerNorm on PSNR but better on LPIPS.
#### Performance
The following metrics were tested on H100 with float32 precision. For fairness, tiling was disabled for all models, and direct inference was performed.
#### Evaluation
We evaluated PSNR and LPIPS on the Panda70M test set at 256 pixels and 33 frames. In the open-source WF-VAE-S (8-dim), our encoder was distilled from the 8-dim OD-VAE, resulting in some metric degradation compared to direct training.
| Latent Dim | Model | Params | PSNR | LPIPS |
|---|---|---|---|---|
| 4 | OD-VAE(Our VAE in v1.2.0) | 94M + 144M | 30.311| 0.043|
| 4 | WFVAE-S | 38M + 108M | 30.579 | 0.044 |
| 8 | WFVAE-S(Distillation) | 38M + 108M | 31.764 | 0.050 |
#### Causal Cache
To address the issue of tiling, we replaced GroupNorm with LayerNorm and introduced a novel method called **Causal Cache**, enabling lossless temporal block-wise inference.
First, we replaced GroupNorm with LayerNorm and utilized the properties of CausalConv3D to achieve lossless inference through temporal dimension chunking. In each layer of CausalConv3D, we cache the information from the previous few frames to maintain continuity during the convolution sliding operation for the next temporal chunk, thereby enabling lossless processing. As illustrated, we use a kernel size of 3 and a stride of 1 as an example:
**Initial Chunk (chunk idx=0):** For the first time chunk, we perform standard causal padding to support joint processing of images and videos. After the convolution operation, we cache the last two frames of this chunk into the causal cache in preparation for the next chunk's inference.
**Subsequent Chunks (chunk idx=1 and beyond):** Starting from the second time chunk, we no longer use causal padding. Instead, we concatenate the cached causal information from the previous chunk to the front of the current chunk. We continue to cache the last two frames of the current input into the causal cache for use in subsequent chunks.
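To make the caching step concrete, here is a minimal sketch in plain Python that reduces the problem to a 1D temporal convolution with kernel size 3 and stride 1. It is not the WFVAE implementation: the function name `causal_conv1d`, the use of one scalar per frame, and the assumption that causal padding repeats the first frame are all simplifications for illustration. It checks that chunk-wise inference with the cache reproduces the result of processing the whole sequence at once.

```python
import random

def causal_conv1d(x, weight, cache=None):
    """Causal temporal convolution with stride 1 over a list of frames.

    cache=None -> first chunk: left-pad by repeating the first frame
                  (assumed form of causal padding).
    cache=list -> later chunk: prepend the cached frames instead of padding.
    Returns (output, new_cache), where new_cache holds the last k-1 input frames.
    """
    k = len(weight)
    prefix = [x[0]] * (k - 1) if cache is None else list(cache)
    padded = prefix + list(x)
    out = [sum(w * v for w, v in zip(weight, padded[t:t + k]))
           for t in range(len(x))]
    new_cache = padded[len(padded) - (k - 1):]
    return out, new_cache

rng = random.Random(0)
video = [rng.gauss(0.0, 1.0) for _ in range(17)]   # 17 "frames"
weight = [rng.gauss(0.0, 1.0) for _ in range(3)]   # kernel size 3

# Reference: process the whole sequence in one pass.
full, _ = causal_conv1d(video, weight)

# Chunked: first chunk is padded, later chunks reuse the cache.
chunks = [video[0:5], video[5:9], video[9:13], video[13:17]]
chunked, cache = [], None
for chunk in chunks:
    out, cache = causal_conv1d(chunk, weight, cache)
    chunked.extend(out)

chunked_matches_full = (chunked == full)
```

Because every output frame sees exactly the same inputs in both paths, the two results are identical, which is what "lossless" means here. In a real VAE each layer would keep its own cache.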
## Prompt Refiner
User-provided captions are typically fewer than 10 words, whereas the text annotations in the current training data are often dense. This inconsistency between training and inference may result in poor visual quality and weak text alignment. We categorize captions into four types:
(1) Short captions from real user input; we collected 11k from [COCO](https://cocodataset.org/#home).
(2) Captions composed of multiple tags; we collected 5k from [DiffusionDB](https://github.com/poloclub/diffusiondb).
(3) Medium-length captions generated by large language models; 3k sourced from [JourneyDB](https://github.com/JourneyDB/JourneyDB).
(4) Ultra-long, surrealist captions, sourced from Sora/Vidu/Pika/Veo and approximately 0.5k generated by GPT.
We used ChatGPT to rewrite the above captions, with the following instructions provided to ChatGPT:
```
rewrite the sentence to contain subject description action, scene description.
Optional: camera language, light and shadow, atmosphere and
conceive some additional actions to make the sentence more dynamic,
make sure it is a fluent sentence, not nonsense.
```
Finally, we performed LoRA fine-tuning using [LLaMa 3.1](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), completing the training in just 30 minutes with a single H100. We fine-tuned for only 1 epoch, using a batch size of 32 and a LoRA rank of 64. The log can be found [here](https://api.wandb.ai/links/1471742727-Huawei/p5xmkft5). We open-sourced the data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner).
### Data Construction
We randomly sampled from the original Panda70m dataset and found many videos to be static, contain multiple subtitles, or suffer from motion blur. Additionally, the captions in Panda70m did not always accurately describe the video content. To address this, we designed a video filtering pipeline, which retained approximately 27% of the videos after processing.
#### Jump Cut and Detect Motion
We used [LPIPS](https://github.com/richzhang/PerceptualSimilarity) frame-skipping to compute inter-frame semantic similarity, identifying anomalies as cut points and taking the mean as the motion score. We found that videos with motion scores below 0.001 were nearly static, while those above 0.3 exhibited significant jitter and flicker. After applying this method, we manually reviewed 2k videos and concluded that the cut detection accuracy was sufficient for pre-training requirements.
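A rough sketch of this filtering logic follows. It assumes the inter-frame LPIPS distances have already been computed, and it uses a fixed `cut_threshold` as a stand-in for the anomaly rule, which is not specified here; `analyze_clip` and `cut_threshold` are hypothetical names. Only the 0.001 and 0.3 bounds come from the text above.

```python
from statistics import mean

MIN_MOTION = 0.001  # below this, videos were nearly static
MAX_MOTION = 0.3    # above this, videos showed jitter and flicker

def analyze_clip(distances, cut_threshold=0.6):
    """distances[i] is a precomputed LPIPS distance between sampled frames i and i+1.

    cut_threshold is a hypothetical anomaly rule used only for this sketch.
    Returns (cut_points, motion_score, keep).
    """
    # A cut is reported at the index of the frame that starts the new shot.
    cuts = [i + 1 for i, d in enumerate(distances) if d > cut_threshold]
    motion = mean(distances)
    keep = MIN_MOTION <= motion <= MAX_MOTION
    return cuts, motion, keep

static_clip = analyze_clip([0.0, 0.0005, 0.0])
clip_with_cut = analyze_clip([0.05, 0.08, 0.9, 0.06])
```

In practice one would likely split the clip at each cut point first and score each segment separately; the sketch scores the whole clip for brevity.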
#### OCR
We estimated the average position of subtitles on common video platforms to be around 18%. Consequently, we set the maximum crop threshold to 20% of the video's original dimensions and used [EasyOCR](https://github.com/JaidedAI/EasyOCR) to detect subtitles (sampling one frame per second). However, not all videos have subtitles or printed text located at the edges; this method may miss text appearing in central areas, such as in advertisement videos or speeches. Nonetheless, we cannot assume that the presence of text in a video necessitates filtering it out, as certain texts in specific contexts can be meaningful. We leave such judgments to aesthetic considerations.
#### Aesthetic
As before, we used the [Laion aesthetic predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor) for evaluation. Based on the visualization [website](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html), we determined that a score of 4.75 serves as a suitable threshold, effectively filtering out excessive text while retaining high-quality aesthetics. We will add an additional aesthetic prompt, such as `A high-aesthetic scene, ` for data with a score above 6.25.
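The two thresholds can be read as a simple rule. The sketch below assumes that samples scoring below 4.75 are dropped and that the aesthetic prompt is prepended to the caption; the function name and the handling of scores exactly at a threshold are assumptions.

```python
AESTHETIC_MIN = 4.75          # filtering threshold from the text
AESTHETIC_PROMPT_MIN = 6.25   # scores above this get the extra prompt
AESTHETIC_PREFIX = "A high-aesthetic scene, "

def apply_aesthetic_rule(caption, score):
    """Return None to drop the sample, otherwise the caption to train on."""
    if score < AESTHETIC_MIN:
        return None
    if score > AESTHETIC_PROMPT_MIN:
        return AESTHETIC_PREFIX + caption
    return caption
```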
#### Video Quality
Some old photos or videos have very low bit rates, resulting in blurry visual effects even at 480P resolution, often resembling a mosaic appearance. Aesthetic filtering struggles to exclude these videos, as it resizes images to 224 resolution. We aim to establish a metric for assessing absolute video quality, independent of the visual content itself, focusing solely on compression artifacts, low bit rates, and jitter. We employed the technical prediction score from [DOVER](https://github.com/VQAssessment/DOVER) and excluded videos with scores below 0.
#### Recheck Motion
Since some videos contain subtitles, variations in the subtitles may lead to inaccurate motion values. Therefore, we re-evaluated the motion values and discarded static videos.
#### Captioning
We used [QWen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) for video annotation.
```
Please describe the content of this video in as much detail as possible,
including the objects, scenery, animals, characters, and camera movements within the video.
Do not include '\n' in your response.
Please start the description with the video content directly.
Please describe the content of the video and the changes that occur, in chronological order.
```
However, the 7B model tends to generate certain prefixes, such as "This video" or "The video." We compiled a list of all irrelevant opening strings and removed them.
```
'The video depicts ',
'The video captures ',
'In the video, ',
'The video showcases ',
'The video features ',
'The video is ',
'The video appears to be ',
'The video shows ',
'The video begins with ',
'The video displays ',
'The video begins in ',
'The video consists of ',
'The video opens with ',
'The video opens on ',
'The video appears to capture ',
'The video appears to show ',
"The video appears to depict ",
"The video opens in ",
"The video appears to focus closely on ",
"The video starts with ",
"The video begins inside ",
"The video presents ",
"The video takes place in ",
"The video appears to showcase ",
"The video appears to display ",
"The video appears to focus on ",
"The video appears to feature "
```
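Removing these openings is a plain string operation. The sketch below uses a short excerpt of the list above; the function name `strip_prefix` and the choice to re-capitalize the remaining text are assumptions, not necessarily what the project's script does.

```python
# Excerpt of the prefix list above; the full list would be used in practice.
PREFIXES = [
    "The video depicts ",
    "In the video, ",
    "The video shows ",
    "The video appears to be ",
    "The video is ",
]

def strip_prefix(caption, prefixes=PREFIXES):
    """Remove one leading boilerplate prefix, if present."""
    # Try longer prefixes first so a short prefix never shadows a longer one.
    for prefix in sorted(prefixes, key=len, reverse=True):
        if caption.startswith(prefix):
            rest = caption[len(prefix):]
            # Assumption: capitalize the first remaining character.
            return rest[:1].upper() + rest[1:]
    return caption
```

For example, `strip_prefix("The video shows a dog running on a beach.")` yields a caption that starts directly with the content.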
### Training Text-to-Video Diffusion Model
#### Framework
##### Skiparse (Skip-Sparse) Attention
In video generation models, alternating 2+1D spatial-temporal blocks is a commonly used approach, yet these models lack long-range modeling, limiting their performance ceiling. Consequently, models like [CogVideoX](https://arxiv.org/abs/2408.06072), [Meta Movie Gen](https://ai.meta.com/research/movie-gen/), and Open-Sora Plan v1.2 employ **Full 3D Attention** as a denoiser, achieving substantially improved visual fidelity and motion quality compared to 2+1D models. This approach, however, requires calculating attention across all tokens in each clip encoding, which significantly raises training costs. For instance, Open-Sora Plan v1.2, training a 2.7-billion-parameter model, takes **100 seconds per step at 93x720p and over 15 seconds per step at 93x480p**, severely constraining scalability under limited computational resources.
To accelerate training while ensuring adequate performance, we propose the **Skiparse (Skip-Sparse) Attention** method. Specifically, under a fixed sparse ratio $$k$$ , we organize candidate tokens for attention through two alternating skip-gather methods. This approach keeps the attention operation global while effectively reducing FLOPs, enabling faster training of 3D Attention models. In our experiments, applying Skiparse with sparse ratio $$k=4$$ to a 2.7B model reduced training time to **42 seconds per step at 93x720p and 8 seconds per step at 93x480p**.
**Skiparse DiT modifies only the Attention component** within the Transformer Block, using two alternating Skip Sparse Transformer Blocks. With sparse ratio $$k$$, the sequence length in the attention operation reduces to $$\frac{1}{k}$$ of the original, and batch size increases by $$k$$-fold, lowering the theoretical complexity of self-attention to $$\frac{1}{k}$$ of the original, while cross-attention complexity remains unchanged. Due to GPU/NPU parallel processing, increasing the batch size by $$k$$-fold does not linearly decrease speed to $$\frac{1}{k}$$, resulting in a performance boost that exceeds theoretical expectations.
In Single Skip mode, the elements located at positions $$[0, k, 2k, 3k, ...]$$ , $$[1, k+1, 2k+1, 3k+1, ...]$$ , ..., $$[k-1, 2k-1, 3k-1, ...]$$ are grouped into the same scope (with each list forming one scope of elements). The figure above, using $$k=2$$ as an example, illustrates this organizational structure. This concept is straightforward: each token attends to tokens at a stride of $$k$$, that is, with $$k-1$$ tokens skipped in between.
In Group Skip mode, elements at positions $$[(0, 1, ..., k-1), (k^2, k^2+1, ..., k^2+k-1), (2k^2, 2k^2+1, ..., 2k^2+k-1), ...]$$ , $$[(k, k+1, ..., 2k-1), (k^2+k, k^2+k+1, ..., k^2+2k-1), (2k^2+k, 2k^2+k+1, ..., 2k^2+2k-1), ...]$$ , ..., $$[(k^2-k, k^2-k+1, ..., k^2-1), (2k^2-k, 2k^2-k+1, ..., 2k^2-1), (3k^2-k, 3k^2-k+1, ..., 3k^2-1), ...]$$ are grouped together as a scope (with each list forming a scope). This arrangement may seem complex numerically, so it can be helpful to understand with the above figure.
In this pattern, we first **group adjacent tokens** in segments of length $$k$$ , then **bundle these groups** with other groups that are spaced $$k-1$$ groups apart into a single scope. For example, in $$[(0, 1, ..., k-1), (k^2, k^2+1, ..., k^2+k-1), (2k^2, 2k^2+1, ..., 2k^2+k-1), ...]$$ , each set of indices in parentheses represents a group. Each group is then connected with another group that is offset by $$k-1$$ groups, forming one scope.
Since the last index of the first group is $$k-1$$ , the first token in the next group to be linked will be at index $$k-1+k(k-1)+1=k^2$$ . Following this pattern, you can determine the indices for each scope in this configuration.
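The two index patterns are easy to generate directly. The sketch below is an illustration on a flat sequence of `n` tokens, assuming `n` is divisible by $$k^2$$ so that no padding is needed; the function names are hypothetical and this is not the model's actual gather code.

```python
def single_skip_scopes(n, k):
    """Scope r holds tokens r, r+k, r+2k, ... (a stride of k)."""
    return [list(range(r, n, k)) for r in range(k)]

def group_skip_scopes(n, k):
    """Adjacent tokens form groups of length k; group g joins scope g % k."""
    scopes = [[] for _ in range(k)]
    for i in range(n):
        group = i // k
        scopes[group % k].append(i)
    return scopes

single = single_skip_scopes(16, 2)
grouped = group_skip_scopes(16, 2)
```

With `n=16` and `k=2`, the first Single Skip scope is the even indices, while the first Group Skip scope is `(0, 1), (4, 5), (8, 9), (12, 13)`, matching the rule that the second group of a scope starts at index $$k^2$$.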
##### Why "Skiparse"?
The 2+1D DiT models temporal understanding only along the time axis of a single spatial location, theoretically and practically limiting performance. In real-world scenarios, changes at a specific spatial location are typically influenced not by prior content at that same location but by content across all spatial locations at preceding times. This constraint makes it challenging for 2+1D DiT to model complex physical dynamics accurately.
Full 3D Attention represents global attention, allowing any spatial position at any time to access information from any other position across all times, aligning well with real-world physical modeling. However, this approach is time-consuming and inefficient, as visual information often contains considerable redundancy, making it unnecessary to establish attention across all spatiotemporal tokens.
**An ideal spatiotemporal modeling approach should employ attention that minimizes the overhead from redundant visual information while capturing the complexities of the dynamic physical world**. Reducing redundancy requires avoiding connections among all tokens, yet global spatiotemporal attention remains essential for modeling complex physical interactions.
To achieve a balance between 2+1D efficiency and Full 3D’s strong spatiotemporal modeling, we developed Skiparse Attention. This approach provides global spatiotemporal attention within each block, with each block having the same “receptive field”. The use of "group" operations also introduces a degree of locality, aligning well with visual tasks.
Interestingly, once you understand the Skiparse Attention mechanism, you’ll notice that **the attention in 2+1D DiT corresponds to a sparse ratio of $$k=HW$$ (since $$T \ll HW$$ , making the "skip" in Group Skip negligible), while Full 3D DiT corresponds to a sparse ratio of $$k=1$$.** In Skiparse Attention, $$k$$ is typically chosen to be close to 1, yet far smaller than $$HW$$ , making it a 3D Attention that approaches the effectiveness of Full 3D Attention.
In Skiparse Attention, Single Skip is a straightforward operation, easily understood by most. Within Group Skip, the Group operation is also intuitive, serving as a means to model local information. However, **Group Skip involves not only grouping but also skipping**—particularly between groups—which is often overlooked. This oversight frequently leads researchers to confuse Skiparse Attention with a Skip + Window Attention approach. The key difference lies in even-numbered blocks: Window Attention only groups tokens without skipping between groups. The distinctions among these attention methods are illustrated in the figure below, which shows the attention scopes for self-attention only, with dark tokens representing the tokens involved in each attention calculation.
To understand why nearly global attention is necessary and why Skiparse Attention theoretically approximates Full 3D Attention more closely than other common methods, we introduce the concept of **Average Attention Distance**. This concept is defined as follows: for any two tokens, if it takes $$m$$ attention operations to establish a connection between them, the attention distance is $$m$$ . The average attention distance for an attention method on a given tensor is then the mean of the attention distances across all token pairs, representing that method's overall connectivity efficiency.
For example, in Full 3D Attention, any token can connect with any other token in just one attention operation, resulting in an average attention distance of 1.
In 2+1D Attention, the process is somewhat more complex, though still straightforward to understand. In all configurations above, any two different tokens can connect with an attention distance between 1 and 2 (Note that we define the attention distance between a token and itself as zero). Thus, for the other three attention methods, we can first identify which tokens have an attention distance of 1. Subsequently, tokens with an attention distance of 2 can be determined, allowing us to calculate the average attention distance.
In the $$2N$$ Block, attention operates over the $$(H, W)$$ dimensions, where tokens within this region have an attention distance of 1. In the $$2N+1$$ Block, attention operates along the $$(T)$$ dimension, also assigning an attention distance of 1 for these tokens. The total number of tokens with an attention distance of 1 in this case is $$HW + T - 2$$ (excluding the token itself, hence $$(HW + T - 1) - 1 = HW + T - 2$$).
Therefore, in 2+1D Attention, the average attention distance (AVG Attention Distance) is:
$$
\begin{aligned}
d&=\frac{1}{THW}\left[ 1\times 0+\left( HW+T-2 \right) \times 1+\left[ THW-\left( HW+T-1 \right) \right] \times 2 \right]\\
&=2-\left( \frac{1}{T}+\frac{1}{HW} \right)\\
\end{aligned}
$$
In Skip+Window Attention, aside from the token itself, there are $$\frac{THW}{k} - 1$$ tokens with an attention distance of 1 in the $$2N$$ Block, and $$k - 1$$ tokens with an attention distance of 1 in the $$2N+1$$ Block. Thus, the total number of tokens with an attention distance of 1 is $$\frac{THW}{k} + k - 2$$.
Therefore, in Skip+Window Attention, the average attention distance (AVG Attention Distance) is:
$$
\begin{aligned}
d&=\frac{1}{THW}\left[ 1\times 0+\left( \frac{THW}{k}+k-2 \right) \times 1+\left[ THW-\left( \frac{THW}{k}+k-1 \right) \right] \times 2 \right]\\
&=2-\left( \frac{1}{k}+\frac{k}{THW} \right)\\
\end{aligned}
$$
In Skiparse Attention, aside from the token itself, $$\frac{THW}{k} - 1$$ tokens have an attention distance of 1 in the $$2N$$ Block, and $$\frac{THW}{k} - 1$$ tokens have an attention distance of 1 in the $$2N+1$$ Block. Notably, $$\frac{THW}{k^2} - 1$$ tokens can establish an attention distance of 1 in both blocks and should not be counted twice.
Therefore, in Skiparse Attention, the average attention distance (AVG Attention Distance) is:
$$
\begin{aligned}
d&=\frac{1}{THW}\left[ 1\times 0+\left[ \frac{2THW}{k}-2-\left( \frac{THW}{k^2}-1 \right) \right] \times 1+\left[ THW-\left( \frac{2THW}{k}-\frac{THW}{k^2} \right) \right] \times 2 \right]\\
&=2-\frac{2}{k}+\frac{1}{k^2}-\frac{1}{THW}\\
&\approx 2-\frac{2}{k}+\frac{1}{k^2}\quad \left( 1\ll THW \right)\\
\end{aligned}
$$
In fact, in the Group Skip of the $$2N+1$$ Block, the actual sequence length is $$k\lceil \frac{THW}{k^2} \rceil$$ rather than $$\frac{THW}{k}$$. The prior calculation assumes the ideal case where $$k \ll THW$$ and $$k$$ divides $$THW$$ exactly, yielding $$k\lceil \frac{THW}{k^2} \rceil = k \cdot \frac{THW}{k^2} = \frac{THW}{k}$$. In practical applications, excessively large $$k$$ values are typically avoided, making this derivation a reasonably accurate approximation for general use.
Specifically, when $$k = HW$$ and padding is disregarded, since $$T \ll HW$$, group skip attention reduces to window attention with a window size of $$HW$$. Given that padding does not affect the final computation, Skiparse Attention is equivalent to 2+1D Attention when $$k = HW$$.
For the commonly used resolution of 93x512x512, using a causal VAE with a 4x8x8 compression rate and a DiT with a 1x2x2 patch embedding, we obtain a latent shape of 24x32x32 before applying attention. The AVG Attention Distance for different calculation methods would then be as follows:
| | Full 3D Attention | 2+1D Attention |
| ---------------------- | ----------------- | --------------- |
| AVG Attention Distance | 1 | 1.957 |
| | Skip + Window Attention(k=2) | Skip + Window Attention(k=4) | Skip + Window Attention(k=6) | Skip + Window Attention(k=8) |
| ---------------------- | ---------------------------- | ---------------------------- | ---------------------------- | ---------------------------- |
| AVG Attention Distance | 1.500 | 1.750 | 1.833 | 1.875 |
| | Skiparse Attention(k=2) | Skiparse Attention(k=4) | Skiparse Attention(k=6) | Skiparse Attention(k=8) |
| ---------------------- | ----------------------- | ----------------------- | ----------------------- | ----------------------- |
| AVG Attention Distance | 1.250 | 1.563 | 1.694 | 1.766 |
In 2+1D Attention, the average attention distance is 1.957, larger than that of Skip + Window Attention and Skiparse Attention at commonly used sparse ratios. While Skip + Window Attention achieves a shorter average attention distance, its modeling capacity remains limited due to the locality of attention in its 2N+1 blocks. Skiparse Attention, with the shortest average attention distance, applies global attention in both 2N and 2N+1 blocks, making its spatiotemporal modeling capabilities closer to Full 3D Attention than the other two non-Full 3D methods.
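The table entries follow directly from the closed-form expressions derived above. The sketch below evaluates them for the 24x32x32 latent and also brute-forces the Skiparse case on a tiny sequence as a sanity check. The function names are hypothetical, and the brute-force check relies on the same assumption as the derivation: any two distinct tokens are at distance 1 or 2, and the sequence length is divisible by $$k^2$$.

```python
def avg_distance_2p1d(T, HW):
    return 2 - (1 / T + 1 / HW)

def avg_distance_skip_window(N, k):
    return 2 - (1 / k + k / N)

def avg_distance_skiparse(N, k, exact=False):
    # exact=True keeps the 1/N term; the table uses the approximation.
    d = 2 - 2 / k + 1 / k**2
    return d - 1 / N if exact else d

def skiparse_bruteforce(n, k):
    """Average attention distance by enumeration over all token pairs."""
    total = 0
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # distance to itself is 0
            same_single = i % k == j % k                # Single Skip scope
            same_group = (i // k) % k == (j // k) % k   # Group Skip scope
            total += 1 if (same_single or same_group) else 2
    return total / n**2

T, H, W = 24, 32, 32
N = T * H * W
table = {
    "2+1D": avg_distance_2p1d(T, H * W),
    "skip_window": {k: avg_distance_skip_window(N, k) for k in (2, 4, 6, 8)},
    "skiparse": {k: avg_distance_skiparse(N, k) for k in (2, 4, 6, 8)},
}
```

The values agree with the tables to within rounding in the third decimal place, and the enumeration on small sequences matches the exact formula including the $$\frac{1}{THW}$$ term.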
The figure above shows how Skiparse Attention’s AVG Attention Distance changes with sparse ratio $$k$$.
We can summarize the characteristics of these attention types as follows:
| | Full 3D Attention | 2+1D Attention | Skip + Window Attention | Skiparse Attention |
| ---------------------------------- | ----------------- | -------------------------------- | --------------------------------------- | ------------------------------------------------------------ |
| Speed | Slow | Fast | Depending on $$k$$ | Depending on $$k$$ |
| Spatiotemporal modeling capability | Strong | Weak | Weak | Approaches Full 3D |
| Is attention global? | Yes | No | Half of the attention blocks are global | Yes |
| Computation load per block | Equal | Not Equal | Not Equal | Equal |
| AVG Attention Distance | 1 | $$2-(\frac{1}{T}+\frac{1}{HW})$$ | $$2-(\frac{1}{k}+\frac{k}{THW})$$ | $$2-\frac{2}{k}+\frac{1}{k^2},\ 1<k$$ |
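The closed-form expressions in the last row can be checked against the numeric tables above. The snippet below is a small sketch; the function names are ours, and the Skip + Window values are reproduced only under the assumption that $$THW$$ is much larger than $$k$$.

```python
def skiparse_avg_distance(k):
    # 2 - 2/k + 1/k^2, the Skiparse Attention entry of the summary table
    return 2 - 2 / k + 1 / k**2

def skip_window_avg_distance(k, thw):
    # 2 - (1/k + k/THW), the Skip + Window Attention entry
    return 2 - (1 / k + k / thw)

def attn_2p1d_avg_distance(t, hw):
    # 2 - (1/T + 1/HW), the 2+1D Attention entry
    return 2 - (1 / t + 1 / hw)

# Assumption: a very long sequence, so the k/THW term is negligible.
LARGE_THW = 10**9
for k in (2, 4, 6, 8):
    print(k, round(skip_window_avg_distance(k, LARGE_THW), 3),
          round(skiparse_avg_distance(k), 3))
```

For every listed $$k$$, Skiparse Attention gives the smaller distance of the two sparse variants, which matches the tables.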
#### Dynamic training
Overall, we maintained the bucket strategy from v1.2.0, pre-defining the shape of each video during training and aggregating data of the same shape through a sampler. Finally, the dataloader retrieves data based on our aggregated indices.
In our early implementation, we specified `--max_width`, `--max_height`, `--min_width`, and `--min_height`. While this allows arbitrary resolutions within a certain range, it can easily lead to OOM issues during video training. For instance, for a 720P (720×1280) video, if the maximum dimensions are set to 720, the video would be scaled to 405×720. However, square videos with resolutions greater than 720 would be scaled to 720×720. Most videos are non-square, so to prevent OOM on the occasional square video we would need to reserve GPU memory for it, which leads to significant computational waste. Therefore, we recommend using `--max_token` and `--min_token` to bound the token count directly, as this aligns better with the Transformer architecture.
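The arithmetic behind this recommendation can be sketched as follows. This is an illustration only: the helper names are hypothetical, the per-frame pixel area is used as a stand-in for the token count, and the exact unit and rounding used by `--max_token` are defined by the training code, not here.

```python
import math

def fit_max_side(h, w, max_side):
    # Side-length limit: shrink so that the longer side is at most max_side.
    scale = min(1.0, max_side / max(h, w))
    return round(h * scale), round(w * scale)

def fit_area_budget(h, w, max_area):
    # Area limit (proxy for a token budget): shrink so that h * w <= max_area.
    scale = min(1.0, math.sqrt(max_area / (h * w)))
    return int(h * scale), int(w * scale)

wide = fit_max_side(720, 1280, 720)       # landscape 720P clip
square = fit_max_side(1080, 1080, 720)    # hypothetical large square clip
print(wide, wide[0] * wide[1])            # 405x720
print(square, square[0] * square[1])      # 720x720, about 1.78x the area

# With an area budget, both clips end up with the same cost.
budget = 405 * 720
print(fit_area_budget(720, 1280, budget), fit_area_budget(1080, 1080, budget))
```

Under a side-length limit the rare square clip dictates the memory that must be reserved; under an area budget every clip has a similar sequence length.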
#### Training scheduler
We replaced the eps-pred loss with a v-pred loss and enabled ZeroSNR. Videos are resampled to 16 FPS for training.
**Stage 1**: We initialized from the image weights of version 1.2.0 and trained images at a resolution of 1x320x320. The objective of this phase was to fine-tune the 3D dense attention model to a sparse attention model. The entire fine-tuning process involved approximately 100k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily sourced from SAM in version 1.2.0.
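For readers unfamiliar with v-prediction, the sketch below shows the commonly used definition of the target, `v = alpha * eps - sigma * x0`, for a variance-preserving schedule with `alpha**2 + sigma**2 == 1`. It is the standard textbook formulation on scalars, not a copy of the training code in this repository, and the function names are ours.

```python
import math

def v_target(x0, eps, alpha, sigma):
    # v-prediction target for the noisy sample x_t = alpha * x0 + sigma * eps
    return alpha * eps - sigma * x0

def x0_from_v(x_t, v, alpha, sigma):
    # With alpha^2 + sigma^2 = 1, the clean sample is recovered from x_t and v.
    return alpha * x_t - sigma * v

x0, eps = 0.8, 0.3
alpha, sigma = math.cos(1.0), math.sin(1.0)
x_t = alpha * x0 + sigma * eps
recovered = x0_from_v(x_t, v_target(x0, eps, alpha, sigma), alpha, sigma)

# At zero terminal SNR (alpha = 0, sigma = 1) the input is pure noise, so an
# eps target carries no information about x0, while the v target equals -x0.
v_terminal = v_target(x0, eps, 0.0, 1.0)
```

This is one reason v-prediction is usually paired with a zero-terminal-SNR schedule: the target stays informative at the last timestep.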
**Stage 2**: We trained the model jointly on images and videos, with a maximum resolution of 93x320x320. The entire fine-tuning process involved approximately 300k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily sourced from SAM in version 1.2.0, while the video data consisted of the unfiltered Panda70m. In fact, the model had nearly converged around 100k steps, and by 300k steps, there were no significant gains. Subsequently, we performed data cleaning and caption rewriting, with further data analysis discussed at the end.
**Stage 3**: We fine-tuned the model using our filtered Panda70m dataset, with a fixed resolution of 93x352x640. The entire fine-tuning process involved approximately 30k steps, with a batch size of 1024 and a learning rate of 1e-5.
### Training Image-to-Video Diffusion Model
#### Framework
In terms of framework, Open-Sora Plan v1.3 continues to use the Inpainting model architecture from Open-Sora Plan v1.2.
#### Data processing
For data processing, Open-Sora Plan v1.3 introduces two new mask types: an all-1 mask and an all-0 mask. This brings the total number of mask types in the Inpainting Model to six.
In the figure above, black indicates retained frames, while white denotes discarded frames. The corresponding frame strategies are as follows:
- **Clear**: Retain all frames.
- **T2V**: Discard all frames.
- **I2V**: Retain only the first frame; discard the rest.
- **Transition**: Retain only the first and last frames; discard the rest.
- **Continuation**: Retain the first $$n$$ frames; discard the rest.
- **Random**: Retain $$n$$ randomly selected frames; discard the rest.
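The six strategies can be written down as per-frame keep masks. The sketch below uses our own convention (`True` means the frame is retained) and a hypothetical helper name; how the repository encodes the all-1 and all-0 masks internally is not specified here.

```python
import random

def build_frame_mask(kind, num_frames, n=None, rng=random):
    """Return a list of booleans, one per frame; True = frame is retained."""
    if kind == "clear":
        return [True] * num_frames
    if kind == "t2v":
        return [False] * num_frames
    if kind == "i2v":
        return [i == 0 for i in range(num_frames)]
    if kind == "transition":
        return [i in (0, num_frames - 1) for i in range(num_frames)]
    if kind == "continuation":
        return [i < n for i in range(num_frames)]
    if kind == "random":
        keep = set(rng.sample(range(num_frames), n))
        return [i in keep for i in range(num_frames)]
    raise ValueError(f"unknown mask type: {kind}")

masks = {kind: build_frame_mask(kind, 93, n=12)
         for kind in ("clear", "t2v", "i2v", "transition", "continuation", "random")}
```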
#### Progressive training
Open-Sora Plan v1.3 uses more data for training and employs a progressive training approach to help the model understand frame-based inpainting tasks.
Since the Inpainting Model supports various mask inputs, different mask inputs correspond to tasks of varying difficulty levels. Therefore, we can first teach the model simple tasks, such as random masks, allowing it to develop a basic capability for frame-based inpainting before gradually increasing the proportion of more challenging tasks. It is important to note that at different training stages, we ensure that at least 5% of the data the model sees pertains to T2V tasks, which is aimed at enhancing the model's understanding of prompts.
The model weights are initialized from the T2V model with zero initialization. The batch size is fixed at 256, and the learning rate is set to 1e-5, using a two-stage training approach.
**Stage 1**: Any resolution and duration within 93x102400 (320x320), using unfiltered motion and aesthetic low-quality data:
(1) Step 1: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 50% of the frames are retained during continuation and random mask, training with 4 million samples.
(2) Step 2: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 25% of the frames are retained during continuation and random mask, training with 4 million samples.
(3) Step 3: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 12.5% of the frames are retained during continuation and random mask, training with 4 million samples.
(4) Step 4: t2v 10%, continuation 25%, random mask 60%, clear 5%. Ensure that at least 12.5% of the frames are retained during continuation and random mask, training with 4 million samples.
(5) Step 5: t2v 10%, continuation 25%, random mask 60%, clear 5%, training with 8 million samples.
(6) Step 6: t2v 10%, continuation 10%, random mask 20%, i2v 40%, transition 20%, training with 16 million samples.
(7) Step 7: t2v 5%, continuation 5%, random mask 10%, i2v 40%, transition 40%, training with 10 million samples.
**Stage 2:** Any resolution and duration within 93x236544 (e.g., 480x480, 640x352, 352x640), using filtered motion and aesthetic high-quality data:
t2v 5%, continuation 5%, random mask 10%, i2v 40%, transition 40%, training with 15 million samples.
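The schedule above can be expressed as plain data, which makes it easy to verify that each task mix sums to 100% and that T2V never drops below 5%. The structure and field names below are hypothetical; only the numbers come from the text.

```python
import random

# (task mix in percent, minimum retained-frame ratio or None, samples in millions)
STAGE1_STEPS = [
    ({"t2v": 10, "continuation": 40, "random": 40, "clear": 10}, 0.5, 4),
    ({"t2v": 10, "continuation": 40, "random": 40, "clear": 10}, 0.25, 4),
    ({"t2v": 10, "continuation": 40, "random": 40, "clear": 10}, 0.125, 4),
    ({"t2v": 10, "continuation": 25, "random": 60, "clear": 5}, 0.125, 4),
    ({"t2v": 10, "continuation": 25, "random": 60, "clear": 5}, None, 8),
    ({"t2v": 10, "continuation": 10, "random": 20, "i2v": 40, "transition": 20}, None, 16),
    ({"t2v": 5, "continuation": 5, "random": 10, "i2v": 40, "transition": 40}, None, 10),
]
STAGE2 = ({"t2v": 5, "continuation": 5, "random": 10, "i2v": 40, "transition": 40}, None, 15)

def sample_task(mix, rng=random):
    # Draw one task type for a training sample according to the mix.
    tasks = list(mix)
    return rng.choices(tasks, weights=[mix[t] for t in tasks], k=1)[0]

stage1_total_millions = sum(samples for _, _, samples in STAGE1_STEPS)
print(stage1_total_millions, sample_task(STAGE1_STEPS[0][0]))
```

Summing the per-step sample counts gives 50 million samples for Stage 1, followed by 15 million for Stage 2.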
#### About the Semantic Adapter
We conducted further experiments on the Semantic Adapter module and compared the video quality of Image-to-Video under various Image Encoders, including [Clip](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K) and [Dino v2](https://huggingface.co/timm/vit_large_patch14_dinov2.lvd142m). We also attempted strategies such as directly injecting image embeddings into cross-attention or extracting features from the Image Encoder using Qformer before injecting them into cross-attention.
Under various strategies, we did not observe significant performance improvements; the impact on video quality was much smaller than that of the dataset. Therefore, we decided not to include the Semantic Adapter in Open-Sora Plan v1.3.
#### Noise Injection Strategy for Conditional Images
Studies such as [CogVideoX](https://arxiv.org/abs/2408.06072) and [Stable Video Diffusion](https://stability.ai/stable-video) have indicated that adding a certain amount of noise to conditional images can enhance the generalization capability of I2V models and achieve a greater range of motion. Therefore, we adopt this strategy in Open-Sora Plan v1.3, following the same approach as [CogVideoX](https://arxiv.org/abs/2408.06072).
### The implementation of Skiparse Attention
Skiparse is theoretically easy to understand and straightforward to implement. Its implementation mainly relies on the rearrange operation, which reduces the sequence length of latents before entering `F.scaled_dot_product_attention()`. Aside from this adjustment, no other modifications are made. For simplicity, the following discussion focuses solely on the self-attention part, excluding the attention mask.
The pseudocode implementation of Single Skip is as follows:
```python
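import torch.nn.functional as F
from einops import rearrange

# x.shape: (B, N, C); N must be divisible by sparse_k
def single_skip_rearrange(x, sparse_k):
    # Token i goes to sub-sequence i % sparse_k, which is folded into the batch dim.
    return rearrange(x, 'b (g k) d -> (k b) g d', k=sparse_k)

def reverse_sparse(x, sparse_k):
    # Inverse of single_skip_rearrange: restores the original token order.
    return rearrange(x, '(k b) g d -> b (g k) d', k=sparse_k)

# Q, K, V (projections) and add_rope are placeholders for the model's own modules.
q, k, v = Q(x), K(x), V(x)
q = add_rope(q)
k = add_rope(k)
q = single_skip_rearrange(q, sparse_k)
k = single_skip_rearrange(k, sparse_k)
v = single_skip_rearrange(v, sparse_k)
hidden_states = F.scaled_dot_product_attention(q, k, v)
output = reverse_sparse(hidden_states, sparse_k)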
```
The core of the Skiparse operation lies in "rearranging the sequence", which corresponds to the Single Skip operation in the pseudocode:
```python
rearrange(x, 'b (g k) d -> (k b) g d', k=sparse_k)
```
This operation can be understood as a reshape followed by a transpose: the sequence is reshaped into rows of $$k$$ tokens, and the rows and columns are then swapped.
In this way, $$k$$ sub-sequences can be created, and $$k$$ can be moved to the batch dimension, allowing the Attention mechanism to compute the sub-sequences in parallel.
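The reshape-then-transpose view can be reproduced on a plain list of token indices, without any tensor library. This is an illustrative sketch with our own helper names; it mirrors the index pattern of the rearrange above for a single batch element.

```python
def single_skip(tokens, k):
    # Reshape the sequence into rows of k tokens: shape (g, k) ...
    g = len(tokens) // k
    rows = [tokens[i * k:(i + 1) * k] for i in range(g)]
    # ... then transpose to (k, g): each column becomes one sub-sequence.
    return [list(col) for col in zip(*rows)]

def reverse_single_skip(subseqs):
    # Transpose back and flatten to restore the original order.
    return [tok for row in zip(*subseqs) for tok in row]

tokens = list(range(8))
subseqs = single_skip(tokens, k=2)
print(subseqs)  # sub-sequence r holds the tokens whose index is r modulo k
```

Each sub-sequence spans the whole original sequence at stride $$k$$, which is why the attention remains global.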
Understanding Single Skip makes Group Skip easy to comprehend as well; it simply adds a grouping operation before the Skip. Its pseudocode is as follows:
```python
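import torch.nn.functional as F
from einops import rearrange

# x.shape: (B, N, C); N must be divisible by sparse_k ** 2
def group_skip_rearrange(x, sparse_k):
    # Within every block of k*k tokens, the m-th group of k tokens goes to sub-sequence m.
    return rearrange(x, 'b (n m k) d -> (m b) (n k) d', m=sparse_k, k=sparse_k)

def reverse_sparse(x, sparse_k):
    # Inverse of group_skip_rearrange: restores the original token order.
    return rearrange(x, '(m b) (n k) d -> b (n m k) d', m=sparse_k, k=sparse_k)

# Q, K, V (projections) and add_rope are placeholders for the model's own modules.
q, k, v = Q(x), K(x), V(x)
q = add_rope(q)
k = add_rope(k)
q = group_skip_rearrange(q, sparse_k)
k = group_skip_rearrange(k, sparse_k)
v = group_skip_rearrange(v, sparse_k)
hidden_states = F.scaled_dot_product_attention(q, k, v)
output = reverse_sparse(hidden_states, sparse_k)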
```
Every $$k^2$$ tokens form a repetition, and every $$k$$ tokens form a group. To help everyone better understand this operation, the following figure illustrates the situation when $$k=3$$:
It is important to note that RoPE must be applied before the Skiparse operation and cannot be placed after it, because the rearranged sequence no longer carries the original spatial positions.
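The Group Skip index pattern for $$k=3$$, and the reason position encoding has to come first, can be seen on plain token indices. The helpers below are an illustrative sketch with our own names, written for a single batch element.

```python
def group_skip(tokens, k):
    # Every k*k tokens form a repetition; inside it, every k tokens form a group.
    if len(tokens) % (k * k) != 0:
        raise ValueError("sequence length must be divisible by k*k")
    reps = len(tokens) // (k * k)
    subseqs = []
    for m in range(k):                      # m-th group of each repetition
        sub = []
        for rep in range(reps):
            start = rep * k * k + m * k
            sub.extend(tokens[start:start + k])
        subseqs.append(sub)
    return subseqs

def reverse_group_skip(subseqs, k):
    reps = len(subseqs[0]) // k
    out = []
    for rep in range(reps):
        for m in range(k):
            out.extend(subseqs[m][rep * k:(rep + 1) * k])
    return out

tokens = list(range(18))
subseqs = group_skip(tokens, k=3)
print(subseqs)
```

In the second sub-sequence, the token at local index 3 is original token 12. Computing positions from local indices after the rearrangement would therefore assign wrong positions.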
## Future Work and Discussion
### CausalVideoVAE
For videos, increasing the compression ratio while maintaining the original latent dimension leads to significant information loss. Therefore, it is a trend to increase the latent dimension to achieve higher compression ratios. A more advanced VAE will be released in the next version.
### Diffusion Model
The current 2B model in version 1.3.0 shows performance saturation during the later stages of training. However, it does not perform well in understanding physical laws (e.g., a cup overflowing with milk, a car moving forward, or a person walking). We have 4 hypotheses regarding this issue:
#### The current data domain is too narrow.
We randomly sampled 2,000 videos from Panda70m and conducted manual verification, finding that less than 1% featured cars in motion, and there were even fewer than 10 videos of people walking. Approximately 80% of the videos consist of half-body conversations with multiple people in front of the camera. Therefore, we speculate that the narrow data domain of Panda70m restricts the model's ability to generate many scenarios. We plan to collect more data in the next version.
#### Joint training of images and videos
Models such as [Open-Sora v1.2](https://github.com/hpcaitech/Open-Sora), [EasyAnimate v4](https://github.com/aigc-apps/EasyAnimate), and [Vchitect-2.0](https://github.com/Vchitect/Vchitect-2.0) can easily generate high-visual-quality videos, possibly due to their direct inheritance of image weights ([Pixart-Sigma](https://pixart-alpha.github.io/PixArt-sigma-project/), [HunyuanDiT](https://github.com/Tencent/HunyuanDiT), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)). They train the model with a small amount of video data to learn how to flow along the time axis based on 2D images. However, we trained images from scratch with only 10M-level data, which is far from sufficient. We have two hypotheses regarding the training strategy: (1) start joint training from scratch, with images significantly outnumbering videos; (2) first train a high-quality image model and then switch to joint training, with a higher proportion of videos at that stage. Considering the learning path and training costs, the second approach may offer more decoupling, while the first aligns better with scaling laws.
#### The model still needs to scale
By observing the differences between [CogVideoX-2B](https://github.com/THUDM/CogVideo) and its 5B variant, we can clearly see that the 5B model understands more physical laws than the 2B model. We speculate that instead of spending excessive effort designing for smaller models, it may be more effective to leverage scaling laws to solve these issues. In the next version, we will scale up the model to explore the boundaries of video generation.
We currently have two plans: one is to continue using the DeepSpeed/FSDP approach, sharding the EMA and text encoder across ranks with ZeRO-3, which is sufficient for training 10-15B models. The other is to adopt [MindSpeed](https://gitee.com/ascend/MindSpeed) for various parallel strategies, enabling us to scale the model up to 30B.
#### Supervised loss in training
Whether flow-based models are more suitable than v-pred models remains uncertain and requires further ablation studies to determine.
### How else can "Skiparse" skip?
The sparse method we use is theoretically and practically straightforward; however, its implementation treats the original video data purely as a one-dimensional sequence, neglecting the 2D spatial priors. Thus, we extended Skiparse to create Skiparse-2D, which is better suited for 2D Visuals.
In Skiparse-2D, a sparse ratio of $$k$$ represents the sparsity along the $$h$$ or $$w$$ direction. In terms of the number of tokens involved in attention computation, a ratio of $$k$$ in Skiparse-2D is equivalent to a ratio of $$k^2$$ in Skiparse-1D; for example, $$k=2$$ in Skiparse-2D corresponds to $$k=4$$ in Skiparse-1D.
We conducted basic experiments comparing Skiparse-1D and Skiparse-2D. Under identical experimental settings, Skiparse-2D showed no improvement over Skiparse-1D in terms of loss or sampling results. Additionally, Skiparse-2D is less flexible to implement than Skiparse-1D. Therefore, we opted to use the Skiparse-1D approach for training in Open-Sora Plan v1.3.
Nevertheless, given our limited experimentation, the feasibility of Skiparse-2D remains worth exploring. Intuitively, Skiparse-2D better aligns with the spatial characteristics of visuals, and as sparse ratio $$k$$ increases, its approach intuitively approximates that of 2+1D. We therefore encourage interested researchers in the community to pursue further exploration in this area.
================================================
FILE: docs/Report-v1.5.0.md
================================================
## Report v1.5.0
In October 2024, we released Open-Sora Plan v1.3.0, introducing the sparse attention structure, Skiparse Attention, to the field of video generation for the first time. Additionally, we adopted the efficient WFVAE, significantly reducing encoding time and memory usage during training.
In Open-Sora Plan v1.5.0, we introduce several key updates to enhance the framework:
1、Improved Sparse DiT, SUV. Building on Skiparse Attention, we extend sparse DiT into a U-shaped sparse structure. This design preserves speed advantages while enabling sparse DiT to achieve performance comparable to dense DiT.
2、Higher-compression WFVAE. In Open-Sora Plan v1.5.0, we explore a WFVAE with an 8×8×8 downsampling rate. Its performance is comparable to that of the 4×8×8 VAEs widely adopted in the community, while it halves the latent shape and shortens the attention sequence length.
3、Data and model scaling. In Open-Sora Plan v1.5.0, we collect 1.1 billion high-quality images and 40 million high-quality videos. The model is scaled up to 8.5 billion parameters, resulting in strong overall performance.
4、Simplified Adaptive Gradient Clipping strategy. Compared to the more complex batch-dropping method in version 1.3.0, version 1.5.0 maintains a simple adaptive gradient norm threshold for clipping, making it more compatible with various parallel training strategies.
Open-Sora Plan v1.5.0 is fully trained and inferred on Ascend 910-series accelerators, using the mindspeed-mm framework to support parallel training strategies.
### Open-Source Release
Open-Sora Plan v1.5.0 is open-sourced with the following components:
1、All training and inference code. You can also find the implementation of Open-Sora Plan v1.5.0 in the official [MindSpeed-MM](https://gitee.com/ascend/MindSpeed-MM) repository.
2、The WFVAE weights with 8×8×8 compression, along with the 8.5B SUV denoiser weights.
## Detailed Technical Report
### Data collection and processing
Our dataset includes 1.1B images from [Recap-DataComp-1B](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), and [LAION-Aesthetics](https://laion.ai/blog/laion-aesthetics/), with no filtering applied aside from resolution checks. The video data are drawn from [Panda-70M](https://github.com/snap-research/Panda-70M) and internal sources, and filtered using the same protocol as in Open-Sora Plan v1.3.0, yielding 40M high-quality videos.
### Adaptive Grad Clipping
In Open-Sora Plan v1.3.0, we introduced an Adaptive Grad Clipping strategy based on discarding gradient-abnormal batches. While highly stable, this method involves overly complex execution logic. In Open-Sora Plan v1.5.0, we optimize the strategy by maintaining the gradient norm threshold via an exponential moving average (EMA). Gradients exceeding the threshold are clipped accordingly. This approach effectively extends the fixed threshold of 1.0, which is commonly used in large-scale models, into a dynamic, training-dependent threshold.
```python
'''
moving_avg_max_grad_norm: the maximum gradient norm maintained via EMA
moving_avg_max_grad_norm_var: the variance of the maximum gradient norm maintained via EMA
clip_threshold: the gradient clipping threshold computed using the 3-sigma rule
ema_decay: the EMA decay coefficient, typically set to 0.99.
grad_norm: grad norm at the current step
'''
clip_threshold = moving_avg_max_grad_norm + 3.0 * (moving_avg_max_grad_norm_var ** 0.5)
if grad_norm <= clip_threshold:
    # If the gradient norm is below the clipping threshold, the parameters are updated normally at this step, and both the moving_avg_max_grad_norm and moving_avg_max_grad_norm_var are updated accordingly.
    moving_avg_max_grad_norm = ema_decay * moving_avg_max_grad_norm + (1 - ema_decay) * grad_norm
    max_grad_norm_var = (moving_avg_max_grad_norm - grad_norm) ** 2
    moving_avg_max_grad_norm_var = ema_decay * moving_avg_max_grad_norm_var + (1 - ema_decay) * max_grad_norm_var
    # update weights...
else:
    # If the gradient norm exceeds the clipping threshold, the gradients are first clipped to reduce the norm to the threshold value before updating the parameters.
    clip_coef = clip_threshold / grad_norm  # < 1: scaling by this factor brings the norm down to the threshold
    grads = [g * clip_coef for g in grads]  # clipping grads
    # update weights...
```
Compared to the strategy in v1.3.0, this approach is simpler to implement and effectively addresses the issue of loss spikes that occur in the later stages of diffusion training when the gradient norm is significantly below 1.0.
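The same logic can be packaged as a small stateful helper. The class below is a self-contained sketch that works on scalar gradient norms only; the class name, method names, and initial values are our assumptions, and the actual training code may initialize and apply the statistics differently.

```python
class AdaptiveGradClipper:
    """Tracks an EMA of the gradient norm and its variance (3-sigma threshold)."""

    def __init__(self, init_norm=1.0, init_var=0.0, ema_decay=0.99):
        # init_norm and init_var are hypothetical starting values.
        self.avg = init_norm
        self.var = init_var
        self.decay = ema_decay

    def threshold(self):
        return self.avg + 3.0 * self.var ** 0.5

    def step(self, grad_norm):
        """Return the factor by which the gradients should be multiplied."""
        thr = self.threshold()
        if grad_norm <= thr:
            # Normal step: update the running statistics, leave gradients as-is.
            self.avg = self.decay * self.avg + (1 - self.decay) * grad_norm
            deviation = (self.avg - grad_norm) ** 2
            self.var = self.decay * self.var + (1 - self.decay) * deviation
            return 1.0
        # Abnormal step: rescale to the threshold, keep the statistics unchanged.
        return thr / grad_norm

clipper = AdaptiveGradClipper(init_norm=1.0, init_var=0.01)
scale_ok = clipper.step(0.5)            # below threshold, no clipping
thr_before_spike = clipper.threshold()
scale_spike = clipper.step(10.0)        # spike, gradients are scaled down
```

Because spikes do not update the running statistics, a single abnormal batch cannot inflate the threshold for later steps.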
### WFVAE with 8x8x8 compression
In version 1.5.0, we increase the temporal compression rate of the VAE from 4× to 8×, reducing the latent shape to half that of the previous version. This enables the generation of videos with higher frame counts.
| Model | THW(C) | PSNR | LPIPS | rFVD |
| ----------------- | ------------- | ------------ | ------------- | ------------ |
| CogVideoX | 4x8x8 (16) | 36.38 | 0.0243 | 50.33 |
| StepVideo | 8x16x16 (16) | 33.61 | 0.0337 | 113.68 |
| LTXVideo | 8x32x32 (128) | 33.84 | 0.0380 | 150.87 |
| Wan2.1 | 4x8x8 (16) | 35.77 | **0.0197** | **46.05** |
| Ours (WF-VAE-M) | 8x8x8 (32) | **36.91** | 0.0205 | 52.53 |
**Tested on an open-domain dataset with 1K samples.**
For more details on WFVAE, please refer to [WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model](https://arxiv.org/abs/2411.17459)
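The effect of the higher temporal compression on the latent shape can be estimated with a short calculation. The helper below is a sketch that assumes a causal video VAE in which the first frame is encoded on its own and the remaining frames are compressed by the temporal stride; this matches frame counts such as 57 and 121, but it is our assumption and not a statement about the WFVAE code.

```python
def latent_shape(frames, height, width, t_stride, s_stride=8):
    # Assumed causal layout: 1 latent frame for the first frame,
    # then one latent frame per t_stride input frames.
    t = (frames - 1) // t_stride + 1
    return t, height // s_stride, width // s_stride

def num_positions(shape):
    t, h, w = shape
    return t * h * w

old = latent_shape(121, 576, 1024, t_stride=4)   # 4x8x8 compression
new = latent_shape(121, 576, 1024, t_stride=8)   # 8x8x8 compression
print(old, num_positions(old))
print(new, num_positions(new))
```

Under this assumption, a 121×576×1024 clip shrinks from 31 to 16 latent frames, that is, roughly half as many latent positions before any patchification in the DiT.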
### Training Text-to-Video Diffusion Model
#### Framework: SUV, a Sparse U-shaped Diffusion Transformer for Fast Video Generation
In Open-Sora Plan v1.3.0, we discuss the strengths and weaknesses of Full 3D Attention and 2+1D Attention. Based on their characteristics, we propose Skiparse Attention, a novel global sparse attention mechanism.
Under a predefined sparsity $k$, Skiparse Attention selects a subsequence of length $\frac{1}{k}$ of the original sequence in an alternating Single-Skip and Group-Skip pattern for attention interaction. This design approximates the effect of Full 3D Attention. As the sparsity increases, the selected positions become more widely spaced; as it decreases, the positions become more concentrated. Regardless of the sparsity, Skiparse Attention remains global.
In Open-Sora Plan v1.5.0, we interpret this sparse interaction pattern as a form of token-level information downsampling. Sparser Skiparse Attention performs more semantic-level interactions, while denser Skiparse Attention captures fine-grained information. Following the multi-scale design principle in neural networks, we introduce Skiparse Attention with U-shaped sparsity variation: low-sparsity Skiparse Attention is used in shallow layers, with Full 3D Attention applied at the shallowest layer, and high-sparsity Skiparse Attention in deeper layers. Inspired by the UNet architecture, we further incorporate long skip connections between stages with identical sparsity. This U-shaped DiT architecture based on Skiparse Attention is referred to as **SUV**.

In Open-Sora Plan v1.5.0, we adopt an SUV architecture based on MMDiT. Skiparse Attention is applied to the video latents, while the text embeddings are only repeated to align with the skiparse-processed latent shape, without any sparsification.
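As a rough illustration of the U-shaped layout described above, the sketch below expands a per-stage sparsity list into a per-layer schedule and pairs up the stages that would be linked by long skip connections. The stage depths and sparsity values are hypothetical placeholders, not the configuration of the released 8.5B model.

```python
def expand_stages(stage_depths, stage_sparsity):
    # One sparsity value per transformer layer; sparsity 1 means Full 3D Attention.
    if len(stage_depths) != len(stage_sparsity):
        raise ValueError("one sparsity value is required per stage")
    return [k for depth, k in zip(stage_depths, stage_sparsity) for _ in range(depth)]

def long_skip_pairs(stage_sparsity):
    # Pair each stage with its mirror stage when both use the same sparsity.
    n = len(stage_sparsity)
    pairs = []
    for i in range(n // 2):
        j = n - 1 - i
        if stage_sparsity[i] == stage_sparsity[j]:
            pairs.append((i, j))
    return pairs

# Hypothetical example: dense at the shallowest stages, sparsest in the middle.
stage_sparsity = [1, 2, 4, 2, 1]
stage_depths = [2, 4, 8, 4, 2]
per_layer = expand_stages(stage_depths, stage_sparsity)
print(per_layer)
print(long_skip_pairs(stage_sparsity))
```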
The SUV architecture offers the following advantages:
1、SUV is the first sparsification method proven effective for video generation. Our ablation studies show that it achieves performance comparable to dense DiT within approximately the same number of training steps. Moreover, it can be applied during both pretraining and inference. Testing on the Ascend 910B platform at a 121×576×1024 shape shows that SUV runs over 35% faster than Dense DiT, with the attention operation alone more than 45% faster.
2、Unlike UNet structures that explicitly downsample feature maps and cause information loss, the U-shaped structure of SUV operates on attention. The shape of the feature map remains unchanged, preserving information while altering only the granularity of token-level interactions.
3、Skiparse Attention and SUV only change the attention computation during the forward pass instead of modifying model weights. This allows dynamic adjustment of sparsity throughout training: lower sparsity for image or low-resolution video training, and higher sparsity for high-resolution video training. As a result, FLOPs grow approximately linearly as the sequence length increases.
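The linear-growth argument in the last point can be illustrated with a simplified cost model. With sparsity $k$, attention runs on $k$ sub-sequences of length $N/k$, so the pairwise cost is $k \cdot (N/k)^2 = N^2/k$; if $k$ is raised in proportion to $N$, the cost grows linearly. The function below is a sketch that counts token pairs only and ignores projections, MLPs, and text tokens.

```python
def attention_pairs(seq_len, k=1):
    # k independent sub-sequences, each with quadratic attention cost.
    sub_len = seq_len // k
    return k * sub_len * sub_len

base = attention_pairs(4096, k=1)          # short sequence, dense attention
dense_long = attention_pairs(16384, k=1)   # 4x longer, still dense
sparse_long = attention_pairs(16384, k=4)  # 4x longer, sparsity raised to 4
print(dense_long / base, sparse_long / base)
```

With dense attention the cost grows 16-fold for a 4-fold longer sequence; with the sparsity raised to 4 it grows only 4-fold.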
A more detailed analysis of the SUV architecture will be released in a future arXiv update.
#### Training Stage
Our training consists of two stages: Text-to-Image and Text-to-Video.
#### Text-to-Image
Previous studies have shown that image weights trained on synthetic data may negatively impact video training. Therefore, in the v1.5.0 update, we choose to train image weights using a much larger corpus of real-world data, totaling 1.1B images. Since image data come in various resolutions, whereas videos are primarily in a 9:16 aspect ratio, we adopt multi-resolution training for images using five common aspect ratios—(1,1), (3,4), (4,3), (9,16), and (16,9)—along with the Min-Max Token Strategy. In contrast, video training is conducted using a fixed 9:16 resolution.
The difference between Skiparse Attention and Full Attention lies in the token sequences involved in the forward computation; the required weights remain identical. Therefore, we can first train the model using Dense MMDiT with Full 3D Attention, and then fine-tune it to the Sparse MMDiT mode after sufficient training.
**Image-Stage-1:** Training is conducted using 512 Ascend 910B accelerators. We train a randomly initialized Dense MMDiT on 256²-pixel images with multi-resolution enabled. The learning rate is set to 1e-4, with a batch size of 8096. This stage runs for a total of 225k steps.
**Image-Stage-2:** Training is conducted using 384 Ascend 910B accelerators. We train on 384²-pixel images with multi-resolution still enabled. The learning rate remains 1e-4, the batch size is 6144, and training lasts for 150k steps.
**Image-Stage-3:** Training is conducted using 256 Ascend 910B accelerators. We train on 288x512 images at a fixed resolution. The learning rate is 1e-4, the batch size is 4096, and training lasts for 110k steps. This stage completes the Dense MMDiT training.
**Image-Stage-4:** Training is conducted using 256 Ascend 910B accelerators. We initialize the SUV model using the pretrained weights from Dense MMDiT, with skip connections zero-initialized to ensure that the model can produce non-noise outputs at the start. In practice, zero-shot inference shows that the generated images already contain meaningful low-frequency structures. Our experiments confirm that fine-tuning from Dense DiT to SUV converges quickly. This stage uses a fixed resolution of 288×512, a learning rate of 1e-4, a batch size of 4096, and is trained for approximately 160k steps.
#### Text-to-Video
For video training, we fix the aspect ratio at 9:16 and train solely on video data instead of jointly training with image data. All training in this stage is performed using 512 Ascend 910B accelerators.
**Video-Stage-1:** Starting from the SUV weights pretrained during the Text-to-Image phase, we train on videos with a shape of 57×288×512 for about 40k steps. The setup includes a learning rate of 6e-5, TP/SP parallelism of 2, gradient accumulation set to 2, a micro batch size of 2, and a global batch size of 1024. Videos are trained at 24 fps, representing approximately 2.4 seconds (57/24 ≈ 2.4s) of content per sample. This stage marks the initial adaptation from image-based to video-based weights, for which shorter video clips are intentionally selected to ensure stable initialization.
**Video-Stage-2:** We further train on videos with a shape of 57×288×512 for 45k steps, keeping the learning rate, TP/SP parallelism, and gradient accumulation settings unchanged. However, the training frame rate is reduced to 12 fps, corresponding to ~4.8 seconds of video content per sample (57/12 ≈ 4.8s). This stage aims to enhance temporal learning without increasing sequence length, serving as preparation for later high-frame-count training.
**Video-Stage-3:** We train on videos with a shape of 121×288×512 for approximately 25k steps. The learning rate is adjusted to 4e-5, with TP/SP parallelism set to 4, gradient accumulation steps set to 2, a micro batch size of 4, and a global batch size of 1024. In this stage, we revert to a training frame rate of 24 fps.
**Video-Stage-4:** We conduct training on videos with a shape of 121×576×1024 for a total of 16k + 9k steps. The learning rates are set to 2e-5 and 1e-5 for the two phases, respectively. TP/SP parallelism is configured as 4, with gradient accumulation steps set to 4, a micro batch size of 1, and a global batch size of 512.
**Video-Stage-5:** We train on a high-quality subset of the dataset for 5k steps, using a learning rate of 1e-5. TP/SP parallelism is set to 4, with gradient accumulation steps of 4, a micro batch size of 1, and a global batch size of 512.
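The global batch sizes quoted for the video stages are consistent with the usual relation between world size, model-parallel degree, micro batch size, and gradient accumulation. The check below is a sketch that assumes all ranks inside one TP/SP group process the same samples, so the data-parallel degree is the world size divided by the TP/SP degree; the function name is ours.

```python
def global_batch_size(world_size, tp_sp, micro_batch, grad_accum):
    # Ranks in the same TP/SP group share their samples.
    data_parallel = world_size // tp_sp
    return data_parallel * micro_batch * grad_accum

WORLD_SIZE = 512  # Ascend 910B accelerators used for all video stages
stage_1_2 = global_batch_size(WORLD_SIZE, tp_sp=2, micro_batch=2, grad_accum=2)
stage_3 = global_batch_size(WORLD_SIZE, tp_sp=4, micro_batch=4, grad_accum=2)
stage_4_5 = global_batch_size(WORLD_SIZE, tp_sp=4, micro_batch=1, grad_accum=4)
print(stage_1_2, stage_3, stage_4_5)
```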
#### Performance on VBench
| Model                      | Parameters | Total Score   | Quality Score | Semantic Score | Aesthetic Quality     |
| -------------------------- | ---------- | ------------- | ------------- | -------------- | --------------------- |
| Mochi-1 | 10B | 80.13% | 82.64% | 70.08% | 56.94% |
| CogvideoX-2B | 2B | 80.91% | 82.18% | 75.83% | 60.82% |
| CogvideoX-5B | 5B | 81.61% | 82.75% | 77.04% | 61.98% |
| Step-Video-T2V | 30B | 81.83% | 84.46% | 71.28% | 61.23% |
| CogvideoX1.5-5B | 5B | 82.17% | 82.78% | **79.76%** | 62.79% |
| Gen-3 | - | 82.32% | 84.11% | 75.17% | 63.34% |
| HunyuanVideo (Open-Source) | 13B | **83.24%** | **85.09%** | 75.82% | 60.36% |
| Open-Sora Plan v1.5.0 | 8B | 83.02% | 84.24% | 78.18% | **66.89%** |
### Training Image-to-Video Diffusion Model
Coming Soon...
### Future Work
Currently, open-source models such as Wan2.1 have achieved performance comparable to closed-source commercial counterparts. Given the gap in computing resources and data availability compared to industry-scale efforts, the future development of the Open-Sora Plan will focus on the following directions:
1、Latents Cache.
In the training process of Text-to-Video models, the data must be processed through two key modules—the Variational Autoencoder (VAE) and the Text Encoder—to extract features from both video/images and their corresponding prompts. These encoded features serve as inputs to the training model. However, in existing industry practices, feature encoding is redundantly performed on the multimodal training dataset during every training epoch. This leads to additional computational overhead and significantly prolongs the total training time.
Specifically, in conventional training pipelines, the VAE and Text Encoder modules are typically kept resident in GPU memory to perform feature encoding in real time during each epoch. While this ensures on-the-fly encoding, it also results in persistently high GPU memory usage, becoming a major bottleneck for training efficiency. This issue is exacerbated when handling large-scale datasets or complex models, where memory constraints further limit model capacity and training speed.
To address the above issue, we propose an optimization strategy that replaces repeated feature computation with feature lookup. The core idea is to decouple feature encoding from model training. Specifically, during pretraining or the first training epoch, we compute and store the most computationally expensive text prompt features in external high-performance storage. During subsequent training, the model directly loads these precomputed features from storage, avoiding redundant encoding operations. This design significantly reduces computational overhead and GPU memory usage, allowing more memory to be allocated to model training.
Based on the following configuration environment, we compare the training time per epoch and per step before and after applying the feature caching strategy. Experimental results show that storing precomputed features reduces multi-epoch training time by approximately 30% and frees up around 20% of GPU memory resources.
| **Configuration** | **Details** |
| :---------------: | :-----------------------------------------: |
| Model | Open-Sora Plan v1.5.0 (2B-level parameters) |
| Dataset | 100K images and 10K videos |
| Accelerators | 8× Nvidia A800 GPUs |
| Feature Storage | Huawei OceanStor AI Storage |
Test cases:
| **Training Stage** | **Test Type** | **Batch Size** | **Time per Step** | **Time per Epoch** | **Memory Usage** |
| ------------------ | ---------------------- | -------------- | ----------------- | ------------------ | ---------------- |
| Low-Res Images | General Method | 64 | 6.53s | 21 min 12s | 56 GB |
| | Feature Caching Method | 64 | 4.10s | 13 min 19s | 40 GB |
| | General Method | 128 | 12.78s | 20 min 39s | 74 GB |
| | Feature Caching Method | 128 | 7.81s | 12 min 38s | 50 GB |
| Low-Res Videos | General Method | 8 | 8.90s | 26 min 23s | 68 GB |
| | Feature Caching Method | 8 | 7.78s | 23 min 05s | 51 GB |
| High-Res Videos | General Method | 4 | 17.00s | 101 min | 71 GB |
| | Feature Caching Method | 4 | 16.00s | 97 min | 57 GB |
2、Improved DiT pretraining with sparse or linear attention. In v1.3.0, we introduce the first DiT pretrained with sparse attention in the community. This is extended in v1.5.0 into the SUV architecture, enabling sparse DiT to achieve performance comparable to its dense counterpart. While sparse and linear attention have demonstrated significant success in the LLM domain, their application in video generation remains underexplored. In future versions, we plan to further investigate the integration of sparse and linear attention into video generation models.
3、MoE-based DiT. Since the release of [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), the MoE (Mixture-of-Experts) paradigm has become a common approach for scaling LLMs to larger parameter sizes. Currently, open-source video generation models are capped at around 14B parameters, which is still relatively small compared to the 100B+ scales in the LLM field. Incorporating MoE into the DiT architecture, and exploring its combination with sparse and linear attention, is a future direction under consideration by the Open-Sora Plan team.
4、Unified video generation models for both generation and understanding. The March release of GPT-4o demonstrates that unified architectures combining generation and understanding can offer fundamentally different capabilities compared to purely generative models. In the video domain, we should similarly anticipate the potential breakthroughs that such unified generative models might bring.
5、Enhancing Image-to-Video generation models. Current approaches in this field still largely follow either the SVD paradigm or the inpainting-based paradigm adopted since Open-Sora Plan v1.2.0. Both approaches require extensive fine-tuning of pretrained Text-to-Video models. From a practical standpoint, Text-to-Video is more aligned with academic exploration, while Image-to-Video is more relevant to real-world production scenarios. As a result, developing a new paradigm for Image-to-Video will be a key focus for the Open-Sora Plan team moving forward.
================================================
FILE: docs/Report-v1.5.0_cn.md
================================================
## Report v1.5.0
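In October 2024, we released Open-Sora Plan v1.3.0, which introduced skiparse attention, a sparse attention structure, to the video generation field for the first time. We also adopted the efficient WFVAE, which greatly reduced encoding time and memory usage during training.
Open-Sora Plan v1.5.0 introduces several key updates:
1、A better sparse DiT: SUV. Building on skiparse attention, we extend the sparse DiT to a sparse structure whose sparsity varies in a U shape, so that the sparse DiT reaches performance close to the dense DiT while keeping its speed advantage.
2、WFVAE with a higher compression ratio. In Open-Sora Plan v1.5.0, we experiment with a WFVAE that uses 8x8x8 downsampling. It matches the performance of the 4x8x8 VAEs widely used in the community while halving the latent shape, which shortens the attention sequence length.
3、Data and model scaling. In Open-Sora Plan v1.5.0, we collected 1.1B high-quality images and 40M high-quality videos and scaled the model to 8.5B parameters, yielding a final model with strong performance.
4、Simpler Adaptive Grad Clipping. Compared with the more complex strategy of discarding contaminated batches in version 1.3.0, in version 1.5.0 we simply maintain an adaptive grad norm threshold and clip to it, which adapts better to the needs of different parallelism strategies.
Open-Sora Plan v1.5.0 is trained and inferred entirely on Ascend 910-series accelerators, and uses the mindspeed-mm training framework to adapt the parallelism strategies.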
### Open-Source Release
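The open-source release of Open-Sora Plan v1.5.0 includes:
1、All training and inference code. You can also find the implementation of Open-Sora Plan v1.5.0 in the official [MindSpeed-MM](https://gitee.com/ascend/MindSpeed-MM) repository.
2、The WFVAE weights with 8x8x8 downsampling and the 8.5B SUV denoiser weights.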
## Detailed Technical Report
### Data collection and processing
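We collected a total of 1.1B images from Recap-DataComp-1B, Coyo700M, and Laion-aesthetic. For image data, we apply no filtering other than resolution. Our video data comes from Panda70M and other in-house data. For video data, we filter with the same processing strategy as Open-Sora Plan v1.3.0, which leaves 40M high-quality videos.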
### Adaptive Grad Clipping
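In Open-Sora Plan v1.3.0, we introduced an Adaptive Grad Clipping strategy based on discarding batches with abnormal gradients. That strategy is highly stable, but its execution logic is overly complex. In Open-Sora Plan v1.5.0 we therefore simplify it: we maintain the grad norm threshold with an EMA and, whenever the grad norm exceeds the threshold, clip it down to the threshold. In essence, this extends the constant grad norm threshold of 1.0 commonly used for large models into a threshold that changes dynamically as training progresses.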
```python
'''
moving_avg_max_grad_norm: maximum grad norm, maintained via EMA
moving_avg_max_grad_norm_var: variance of the maximum grad norm, maintained via EMA
clip_threshold: gradient clipping threshold computed with the 3-sigma rule
ema_decay: EMA decay coefficient, typically 0.99
grad_norm: grad norm of the current step
'''
clip_threshold = moving_avg_max_grad_norm + 3.0 * (moving_avg_max_grad_norm_var ** 0.5)
if grad_norm <= clip_threshold:
    # The grad norm is within the clipping threshold: update the parameters normally for this step,
    # and update the maintained moving_avg_max_grad_norm and moving_avg_max_grad_norm_var.
    moving_avg_max_grad_norm = ema_decay * moving_avg_max_grad_norm + (1 - ema_decay) * grad_norm
    max_grad_norm_var = (moving_avg_max_grad_norm - grad_norm) ** 2
    moving_avg_max_grad_norm_var = ema_decay * moving_avg_max_grad_norm_var + (1 - ema_decay) * max_grad_norm_var
    ...  # update parameters
else:
    # The grad norm exceeds the clipping threshold: first clip the grads so that the grad norm
    # drops to clip_threshold, then update the parameters.
    clip_coef = grad_norm / clip_threshold  # > 1 in this branch
    grads = [g / clip_coef for g in grads]  # clip grads (grads is assumed to be an iterable of tensors)
    ...  # update parameters
```
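Compared with the v1.3.0 strategy, this one is simpler to implement, and it copes well with the loss spikes that still occur late in diffusion training, when the grad norm is far below 1.0.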
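The update rule above can be packaged as a small self-contained helper. This is a minimal sketch rather than the project's implementation: the class name `AdaptiveGradClipState`, the function `adaptive_clip_coef`, and the initial values of the running statistics are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass
class AdaptiveGradClipState:
    """Running statistics for the EMA-based threshold (hypothetical helper)."""
    moving_avg_max_grad_norm: float = 1.0      # assumed initial value
    moving_avg_max_grad_norm_var: float = 0.0  # assumed initial value
    ema_decay: float = 0.99


def adaptive_clip_coef(state: AdaptiveGradClipState, grad_norm: float) -> float:
    """Return the factor to multiply all gradients by (1.0 means no clipping)."""
    clip_threshold = (
        state.moving_avg_max_grad_norm
        + 3.0 * state.moving_avg_max_grad_norm_var ** 0.5
    )
    if grad_norm <= clip_threshold:
        # Normal step: fold this grad norm into the running mean and variance.
        d = state.ema_decay
        state.moving_avg_max_grad_norm = (
            d * state.moving_avg_max_grad_norm + (1 - d) * grad_norm
        )
        deviation = (state.moving_avg_max_grad_norm - grad_norm) ** 2
        state.moving_avg_max_grad_norm_var = (
            d * state.moving_avg_max_grad_norm_var + (1 - d) * deviation
        )
        return 1.0
    # Spike: leave the statistics untouched and scale down to the threshold.
    return clip_threshold / grad_norm


state = AdaptiveGradClipState()
normal = adaptive_clip_coef(state, 0.8)  # below the initial threshold of 1.0
spike = adaptive_clip_coef(state, 50.0)  # far above the updated threshold
```

In a training loop, the returned coefficient would be multiplied into every gradient before the optimizer step; because a spike does not update the statistics, a single outlier cannot inflate the threshold.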
### WFVAE with 8x8x8 downsampling
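In v1.5.0, we raise the temporal compression ratio of the VAE from 4x to 8x, so that for a video of the same original size the latent shape is half that of the previous version. This lets us generate videos with more frames.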
| Model | THW(C) | PSNR | LPIPS | rFVD |
| ----------------- | ------------- | ------------ | ------------- | ------------ |
| CogVideoX | 4x8x8 (16) | 36.38 | 0.0243 | 50.33 |
| StepVideo | 8x16x16 (16) | 33.61 | 0.0337 | 113.68 |
| LTXVideo | 8x32x32 (128) | 33.84 | 0.0380 | 150.87 |
| Wan2.1 | 4x8x8 (16) | 35.77 | **0.0197** | **46.05** |
| Ours (WF-VAE-M) | 8x8x8 (32) | **36.91** | 0.0205 | 52.53 |
**Test on an open-domain dataset with 1K samples.**
For details on WFVAE, see [WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model](https://arxiv.org/abs/2411.17459).
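To make the effect of the higher temporal compression concrete, the sketch below compares latent shapes for a 121x576x1024 video. It assumes a causal VAE in which the first frame is encoded on its own and the remaining frames are compressed by the temporal stride, so that `latent_t = (frames - 1) // stride_t + 1`. That formula and the helper names are assumptions, not taken from the report, and any patchification inside the DiT is ignored.

```python
def latent_shape(frames, height, width, stride_t, stride_h=8, stride_w=8):
    """Latent (T, H, W) under the assumed causal-VAE formula."""
    latent_t = (frames - 1) // stride_t + 1
    return latent_t, height // stride_h, width // stride_w


def num_latent_positions(shape):
    """Number of latent positions, i.e. T * H * W."""
    t, h, w = shape
    return t * h * w


shape_4x = latent_shape(121, 576, 1024, stride_t=4)
shape_8x = latent_shape(121, 576, 1024, stride_t=8)
ratio = num_latent_positions(shape_8x) / num_latent_positions(shape_4x)
```

Under these assumptions the temporal length drops from 31 to 16 latent frames, so the number of latent positions is roughly halved.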
### Training Text-to-Video Diffusion Model
#### Framework —— SUV: A Sparse U-shaped Diffusion Transformer For Fast Video Generation
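In Open-Sora Plan v1.3.0, we discussed the strengths and weaknesses of Full 3D Attention and 2+1D Attention, and combined their characteristics to propose Skiparse Attention, a new type of global sparse attention.
Given a pre-specified sparse ratio $k$, Skiparse Attention alternates between Single Skip and Group Skip to select subsequences that are $\frac{1}{k}$ the length of the original sequence for attention, thereby approximating Full 3D Attention. The larger the sparse ratio, the more sparsely a subsequence is positioned within the original sequence; the smaller the sparse ratio, the more densely. Regardless of the sparse ratio, however, Skiparse Attention is always global.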
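The two selection patterns can be illustrated on plain token indices. The sketch below is one plausible reading of Single Skip and Group Skip, not the project's kernel: Single Skip takes every k-th token, and Group Skip first groups k adjacent tokens and then takes every k-th group. The function names and the exact grouping are assumptions.

```python
def single_skip(num_tokens, k):
    """Split indices into k subsequences by taking every k-th token."""
    return [list(range(j, num_tokens, k)) for j in range(k)]


def group_skip(num_tokens, k):
    """Group k adjacent tokens, then take every k-th group."""
    groups = [list(range(g * k, (g + 1) * k)) for g in range(num_tokens // k)]
    return [[t for g in groups[j::k] for t in g] for j in range(k)]


single = single_skip(16, 2)
grouped = group_skip(16, 2)
```

In both patterns each subsequence holds `num_tokens / k` tokens and spans the whole sequence, which is why the attention stays global at any sparse ratio.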
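In Open-Sora Plan v1.5.0, we view this sparse interaction as a form of information downsampling over tokens: sparser Skiparse Attention performs more semantic-level interaction, while denser Skiparse Attention performs more fine-grained interaction. Following the multi-scale design principle of neural networks, we introduce Skiparse Attention whose sparsity varies in a U shape across the network: shallow layers use low-sparsity Skiparse Attention, with Full 3D Attention at the shallowest layers, and deep layers use high-sparsity Skiparse Attention. In particular, by analogy with UNet, we add Long Skip Connections between stages of the same sparsity. We call this U-shaped, Skiparse-Attention-based DiT SUV.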

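In Open-Sora Plan v1.5.0 we adopt an MMDiT-based SUV architecture. For video latents, we apply the skiparse attention operation; for text embeddings, we only repeat them to align with the latent shape after skiparse, without any sparsification.
The SUV architecture has the following advantages:
1、SUV is the first sparsification method validated as effective on video generation models. Our ablation experiments show that it reaches performance close to a dense DiT at the same number of training steps, and it can be applied to both pretraining and inference. On the 910B test platform with a video shape of 121x576x1024, SUV inference is more than 35% faster than Dense DiT, and the attention portion is more than 45% faster.
2、Whereas UNet explicitly downsamples the feature map and thereby loses information, the U-shaped structure of SUV acts on attention: the shape of the feature map does not change, so no information is lost, and only the granularity of information exchange between tokens changes.
3、Skiparse Attention and SUV do not change the size of the weights; they only change how attention is computed in the forward pass. This lets us adjust sparsity dynamically as training progresses: lower sparsity for image or low-resolution video training and higher sparsity for high-resolution video training, which yields FLOPs that grow approximately linearly with sequence length.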
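The near-linear scaling in the last point follows from simple counting. As a rough sketch, assume attention cost is proportional to the number of query-key pairs and that a sparse ratio `k` splits a sequence of length `N` into `k` subsequences of length `N / k`; the cost is then `k * (N / k)^2 = N^2 / k`. The function name below is hypothetical, and projections and MLP cost are ignored.

```python
def attention_pairs(seq_len, sparse_ratio=1):
    """Query-key pairs when the sequence is split into sparse_ratio subsequences."""
    sub_len = seq_len // sparse_ratio
    return sparse_ratio * sub_len * sub_len


dense_4k = attention_pairs(4096)
dense_8k = attention_pairs(8192)
sparse_4k = attention_pairs(4096, sparse_ratio=4)
sparse_8k = attention_pairs(8192, sparse_ratio=8)
```

Doubling the sequence length quadruples the dense cost, but if the sparse ratio is doubled along with it, the sparse cost only doubles.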
A more detailed analysis of the SUV architecture will be posted to arXiv later.
#### Training Stage
Our training consists of two stages: Text-to-Image and Text-to-Video.
#### Text-to-Image
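Prior work suggests that image weights trained on synthetic data may affect the results of subsequent video training. In v1.5.0 we therefore train the image weights on a larger domain of real data, collecting 1.1B images in total. Because images come in many resolutions while videos are mostly 9:16, we enable multi-resolution training (5 common aspect ratios: (1,1), (3,4), (4,3), (9,16), (16,9)) and the Min-Max Token Strategy when training the image weights, and use a fixed 9:16 aspect ratio at a fixed resolution when training on video.
Skiparse Attention differs from Full Attention only in which token sequences take part in the forward computation; the required weights are exactly the same. We can therefore first train a Dense MMDiT with Full 3D Attention and, once it is sufficiently trained, fine-tune it into the Sparse MMDiT mode.
**Image-Stage-1:** Trained on 512 Ascend 910B. We train a randomly initialized Dense MMDiT on images at around 256^2 px resolution with multi-resolution enabled. The learning rate is 1e-4 and the batch size is 8096. This stage runs for 225k steps in total.
**Image-Stage-2:** Trained on 384 Ascend 910B. We train on images at around 384^2 px resolution with multi-resolution enabled. The learning rate is 1e-4 and the batch size is 6144, for 150k steps in total.
**Image-Stage-3:** Trained on 256 Ascend 910B at a fixed resolution of 288x512. The learning rate is 1e-4 and the batch size is 4096, for 110k steps in total. This completes the Dense MMDiT phase.
**Image-Stage-4:** Trained on 256 Ascend 910B. We initialize SUV from the Dense MMDiT weights, with the skip connections zero-initialized so that the initial SUV weights can already produce non-noise images. In fact, images from zero-shot inference contain some low-frequency information, and we verified that fine-tuning from Dense DiT to SUV converges quickly. This stage uses a fixed resolution of 288x512, a learning rate of 1e-4, and a batch size of 4096, for about 160k steps in total.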
#### Text-to-Video
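For video training, we fix the aspect ratio at 9:16 and train on video data only, without joint video-image training. All of the following stages run on 512 Ascend 910B.
**Video-Stage-1:** Starting from the SUV weights obtained in the Text-to-Image stage, we train for about 40k steps on 57x288x512 videos, with a learning rate of 6e-5, TP/SP parallelism of 2, gradient accumulation of 2, micro batch size of 2, and global batch size of 1024. In this stage the train fps is 24, i.e. about 57/24≈2.4s of video content. As the first stage of transferring image weights to video weights, we train on shorter videos to obtain a good initialization.
**Video-Stage-2:** We again train on 57x288x512 videos, for 45k steps, keeping the learning rate, TP/SP parallelism, and gradient accumulation unchanged but changing the train fps to 12, which corresponds to 57/12≈4.8s of the original video. This stage aims to improve temporal learning without increasing the sequence length, in preparation for the later high-frame-count stages.
**Video-Stage-3:** We train for about 25k steps on 121x288x512 videos, with the learning rate adjusted to 4e-5, TP/SP parallelism of 4, gradient accumulation of 2, micro batch size of 4, and global batch size of 1024. In this stage we return to a train fps of 24.
**Video-Stage-4:** We train on 121x576x1024 videos for 16k + 9k steps, with learning rates of 2e-5 and 1e-5 respectively, TP/SP parallelism of 4, gradient accumulation of 4, micro batch size of 1, and global batch size of 512.
**Video-Stage-5:** We train for 5k steps on a high-quality subset of the data, with a learning rate of 1e-5, TP/SP parallelism of 4, gradient accumulation of 4, micro batch size of 1, and global batch size of 512.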
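The global batch sizes quoted for the video stages are consistent with the usual relation between device count, model-parallel degree, micro batch size, and gradient accumulation. The sketch below checks that arithmetic. It assumes that the devices inside one TP/SP group process the same samples, so the data-parallel degree is `num_devices / tp_sp`; the function name is hypothetical.

```python
def global_batch_size(num_devices, tp_sp, micro_batch_size, grad_accum):
    """Global batch = data-parallel degree * micro batch size * accumulation steps."""
    data_parallel = num_devices // tp_sp
    return data_parallel * micro_batch_size * grad_accum


stage_1 = global_batch_size(512, tp_sp=2, micro_batch_size=2, grad_accum=2)
stage_3 = global_batch_size(512, tp_sp=4, micro_batch_size=4, grad_accum=2)
stage_4 = global_batch_size(512, tp_sp=4, micro_batch_size=1, grad_accum=4)
```

With 512 devices this gives 1024, 1024, and 512, matching the figures reported for Video-Stage-1, Video-Stage-3, and Video-Stage-4.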
#### Performance on Vbench
| Model | Parameters | Total Score | Quality Score | Semantic Score | **aesthetic quality** |
| -------------------------- | ---------- | ------------- | ------------- | -------------- | --------------------- |
| Mochi-1 | 10B | 80.13% | 82.64% | 70.08% | 56.94% |
| CogvideoX-2B | 2B | 80.91% | 82.18% | 75.83% | 60.82% |
| CogvideoX-5B | 5B | 81.61% | 82.75% | 77.04% | 61.98% |
| Step-Video-T2V | 30B | 81.83% | 84.46% | 71.28% | 61.23% |
| CogvideoX1.5-5B | 5B | 82.17% | 82.78% | **79.76%** | 62.79% |
| Gen-3 | - | 82.32% | 84.11% | 75.17% | 63.34% |
| HunyuanVideo (Open-Source) | 13B | **83.24%** | **85.09%** | 75.82% | 60.36% |
| Open-Sora Plan v1.5.0 | 8B | 83.02% | 84.24% | 78.18% | **66.89%** |
### Training Image-to-Video Diffusion Model
Coming soon...
### Future Work
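The open-source community now has models whose performance is comparable to closed-source commercial ones, such as Wan2.1. Given that our compute and data still fall short of what companies have, the Open-Sora Plan team will focus on the following directions:
1、Latents Cache.
When training a Text2Video model, the training data must pass through two key modules, the variational autoencoder (VAE) and the text encoder, which encode the videos/images and their prompts into features. These encoded features are the inputs to the rest of the training pipeline. In common industry training setups, however, every epoch recomputes the feature encoding for the multimodal training set, which adds computational overhead and significantly lengthens overall training time.
Specifically, in the conventional training pipeline the VAE and text encoder usually stay resident in GPU memory so that they can encode features on the fly in every epoch. This guarantees real-time encoding but keeps GPU memory usage high, making it one of the main bottlenecks for training efficiency. The problem worsens with large datasets or complex models, where memory pressure limits both model size and training speed.
To address this, we propose an optimization that replaces feature computation with lookup. The core idea is to decouple feature encoding from model training. Concretely, the prompt features, which are the most expensive to compute, are computed before training or during the first epoch and saved to external high-performance file storage. In later training, the model reads these precomputed features directly from file storage and avoids repeated encoding. This design significantly reduces wasted computation and sharply lowers GPU memory usage, freeing more memory for model training.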
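The compute-once, look-up-afterwards idea can be sketched with a small on-disk cache keyed by a hash of the prompt. This is an illustrative sketch only: `PromptFeatureCache`, `toy_encoder`, and the pickle-per-prompt layout are assumptions and stand in for the real text encoder and the external file storage described above.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path
from typing import Callable, List


class PromptFeatureCache:
    """Compute prompt features once, then read them back from disk (hypothetical)."""

    def __init__(self, cache_dir: str, encode_fn: Callable[[str], List[float]]):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.encode_fn = encode_fn  # stand-in for the real text encoder
        self.misses = 0             # number of times the encoder actually ran

    def _path(self, prompt: str) -> Path:
        key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        return self.cache_dir / f"{key}.pkl"

    def get(self, prompt: str) -> List[float]:
        path = self._path(prompt)
        if path.exists():
            # Cache hit: no encoder call, so the encoder need not stay in memory.
            with path.open("rb") as f:
                return pickle.load(f)
        self.misses += 1
        features = self.encode_fn(prompt)
        with path.open("wb") as f:
            pickle.dump(features, f)
        return features


def toy_encoder(prompt: str) -> List[float]:
    """Toy replacement for a text encoder, used only for this demo."""
    return [float(len(prompt)), float(sum(map(ord, prompt)) % 97)]


with tempfile.TemporaryDirectory() as tmp:
    cache = PromptFeatureCache(tmp, toy_encoder)
    first = cache.get("A cartoon kangaroo disco dances.")   # computed and stored
    second = cache.get("A cartoon kangaroo disco dances.")  # read back from disk
    encoder_calls = cache.misses
```

After the first epoch every lookup is a file read, which is what allows the encoder to be dropped from accelerator memory in later epochs.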
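Using the configuration below, we measured per-epoch and per-step training statistics before and after enabling feature storage. The experiments show that the feature storage scheme **shortens multi-epoch training time by about 30% and frees about 20% of GPU memory.**
| Configuration | Details |
| :--------: | :-------------------------------: |
| Model | Open-Sora Plan v1.5.0 (2B scale) |
| Dataset | 100K images and 10K videos |
| Accelerators | 8× Nvidia A800 GPUs |
| Feature Storage | Huawei OceanStor AI Storage |
Test cases:
| **Training Stage** | **Test Type** | **Batch Size** | **Time per Step** | **Time per Epoch** | **Memory Usage** |
| ------------------ | ---------------------- | -------------- | ----------------- | ------------------ | ---------------- |
| Low-Res Images | General Method | 64 | 6.53s | 21 min 12s | 56 GB |
| | Feature Caching Method | 64 | 4.10s | 13 min 19s | 40 GB |
| | General Method | 128 | 12.78s | 20 min 39s | 74 GB |
| | Feature Caching Method | 128 | 7.81s | 12 min 38s | 50 GB |
| Low-Res Videos | General Method | 8 | 8.90s | 26 min 23s | 68 GB |
| | Feature Caching Method | 8 | 7.78s | 23 min 05s | 51 GB |
| High-Res Videos | General Method | 4 | 17.00s | 101 min | 71 GB |
| | Feature Caching Method | 4 | 16.00s | 97 min | 57 GB |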
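2、Improved DiT pretraining with sparse or linear attention. In v1.3.0, we introduced the first DiT pretrained with sparse attention in the community. In v1.5.0 this is extended into the SUV architecture, enabling the sparse DiT to achieve performance comparable to the Dense DiT. Sparse and linear attention have been very successful in the LLM domain, but their application in video generation remains underexplored. In future versions, we will further explore the use of sparse and linear attention in video generation.
3、MoE-based DiT. Since the release of Mixtral 8x7B, the LLM field has commonly used MoE (Mixture-of-Experts) to scale models to larger parameter counts. Open-source video models currently top out at around 14B parameters, which is still small compared with the 100B+ scales in the LLM field. Introducing MoE into the DiT architecture, and combining MoE with sparse and linear attention, is a direction the Open-Sora Plan team is considering.
4、Unified video models for both generation and understanding. The March update of GPT-4o showed that generative models with a unified generation-and-understanding architecture can gain capabilities entirely different from those of purely generative models. In the video domain, we should likewise look forward to what a unified generative model might bring.
5、Better Image-to-Video models. The Image-to-Video field still largely follows either the SVD paradigm or the inpainting paradigm adopted since Open-Sora Plan v1.2.0. Both require lengthy fine-tuning on top of Text-to-Video model weights. In terms of application, Text-to-Video is closer to academic exploration, while Image-to-Video is closer to real-world production. A new paradigm for Image-to-Video will therefore be a key direction for the Open-Sora Plan team.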
================================================
FILE: docs/VAE.md
================================================
### Data preparation
Organizing the training data is simple: place all the videos in one directory, nested to any depth. This makes it convenient to train on multiple datasets.
``` shell
Training Dataset
|——sub_dataset1
    |——sub_sub_dataset1
        |——video1.mp4
        |——video2.mp4
        ......
    |——sub_sub_dataset2
        |——video3.mp4
        |——video4.mp4
        ......
|——sub_dataset2
    |——video5.mp4
    |——video6.mp4
    ......
|——video7.mp4
|——video8.mp4
```
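A recursive layout like the one above can be enumerated in a few lines. The sketch below is not the project's data loader; the function name `list_videos` and the `.mp4`-only default are assumptions for illustration.

```python
import tempfile
from pathlib import Path


def list_videos(root, suffixes=(".mp4",)):
    """Return every video file under root, at any nesting depth, in sorted order."""
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in suffixes
    )


# Demo on a throwaway directory that mirrors part of the tree above.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in (
        "sub_dataset1/sub_sub_dataset1/video1.mp4",
        "sub_dataset2/video5.mp4",
        "video7.mp4",
        "notes.txt",  # non-video files are ignored
    ):
        f = root / rel
        f.parent.mkdir(parents=True, exist_ok=True)
        f.touch()
    found = [p.relative_to(root).as_posix() for p in list_videos(root)]
```

Because the search is recursive, adding another dataset only requires placing its folder under the same root.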
### Training
``` shell
bash scripts/causalvae/train.sh
```
We introduce the important args for training.
| Argparse | Usage |
|:---|:---|
|_Training size_||
|`--num_frames`|The number of frames used from each training video|
|`--resolution`|The resolution of the input to the VAE|
|`--batch_size`|The local batch size on each GPU|
|`--sample_rate`|The frame interval used when loading training videos|
|_Data processing_||
|`--video_path`|/path/to/dataset|
|_Load weights_||
|`--model_name`| `CausalVAE` or `WFVAE`|
|`--model_config`|`/path/to/config.json`. The model config of the VAE. Use this parameter to train from scratch.|
|`--pretrained_model_name_or_path`|A directory containing a model checkpoint and its config. This parameter loads only the model weights, not the optimizer state.|
|`--resume_from_checkpoint`|`/path/to/checkpoint`. Resumes training from the checkpoint, including both the model weights and the optimizer state.|
### Inference
``` shell
bash scripts/causalvae/rec_video.sh
```
We introduce the important args for inference.
| Argparse | Usage |
|:---|:---|
|_Output video size_||
|`--num_frames`|The number of frames of the reconstructed videos|
|`--height`|The height of the reconstructed videos|
|`--width`|The width of the reconstructed videos|
|_Data processing_||
|`--video_path`|The path to the original video|
|`--rec_path`|The path to the generated video|
|_Load weights_||
|`--ae_path`|`/path/to/model_dir`. A directory containing the VAE checkpoint used for inference and its model `config.json`|
|_Other_||
|`--enable_tiling`|Use tiling to handle high-resolution and long-duration videos|
|`--save_memory`|Reduce memory usage during inference at a slight cost in quality|
### Evaluation
The evaluation process consists of two steps:
1. Reconstruct videos in batches: `bash scripts/causalvae/prepare_eval.sh`
2. Evaluate video metrics: `bash scripts/causalvae/eval.sh`
To simplify the evaluation, environment variables are used for control. For step 1 (`bash scripts/causalvae/prepare_eval.sh`):
```bash
# Experiment name
EXP_NAME=wfvae
# Video parameters
SAMPLE_RATE=1
NUM_FRAMES=33
RESOLUTION=256
# Model weights
CKPT=ckpt
# Select subset size (0 for full set)
SUBSET_SIZE=0
# Dataset directory
DATASET_DIR=test_video
```
For step 2 (`scripts/causalvae/eval.sh`):
```bash
# Experiment name
EXP_NAME=wfvae-4dim
# Video parameters
SAMPLE_RATE=1
NUM_FRAMES=33
RESOLUTION=256
# Evaluation metric
METRIC=lpips
# Select subset size (0 for full set)
SUBSET_SIZE=0
# Path to the ground truth videos, which can be saved during video reconstruction by setting `--output_origin`
ORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin
# Path to the reconstructed videos
RECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}
```
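The two directory variables in step 2 are derived from the experiment settings, so both steps must use the same values for the paths to line up. The sketch below reproduces that naming scheme in Python purely as an illustration; the function name `recon_dir` is hypothetical, and the scripts themselves build these paths in shell.

```python
def recon_dir(exp_name, sample_rate, num_frames, resolution, subset_size, root="video_gen"):
    """Mirror the RECON_DIR pattern used in the evaluation script."""
    return (
        f"{root}/{exp_name}_sr{sample_rate}_nf{num_frames}"
        f"_res{resolution}_subset{subset_size}"
    )


recon = recon_dir("wfvae-4dim", sample_rate=1, num_frames=33, resolution=256, subset_size=0)
origin = f"{recon}/origin"  # ground-truth videos live in a subfolder
```

If `EXP_NAME` differs between the two steps, the evaluation step will look in a directory that the reconstruction step never wrote.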
================================================
FILE: examples/cond_pix_path.txt
================================================
examples/test_img1.png
examples/test_img2.png
examples/test_img3.png
================================================
FILE: examples/cond_prompt.txt
================================================
A rocket ascends slowly into the sky.
Along the coast, variously sized boats float on the lake.
The landscape at sunset is profound and expansive.
================================================
FILE: examples/rec_image.py
================================================
import sys
sys.path.append(".")
from PIL import Image
import torch
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Lambda
from torch.nn import functional as F
import argparse
import numpy as np
from opensora.models.causalvideovae import ae_wrapper
def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:
    transform = Compose(
        [
            ToTensor(),
            Lambda(lambda x: 2. * x - 1.),
            Resize(size=short_size),
        ]
    )
    outputs = transform(video_data)
    outputs = outputs.unsqueeze(0).unsqueeze(2)
    return outputs
def main(args: argparse.Namespace):
    image_path = args.image_path
    short_size = args.short_size
    device = args.device
    kwarg = {}
    # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device)
    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor
    vae.eval()
    vae = vae.to(device)
    vae = vae.half()
    with torch.no_grad():
        x_vae = preprocess(Image.open(image_path), short_size)
        x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w
        latents = vae.encode(x_vae)
        latents = latents.to(torch.float16)
        image_recon = vae.decode(latents)  # b t c h w
    x = image_recon[0, 0, :, :, :]
    x = x.squeeze()
    x = x.detach().cpu().numpy()
    x = np.clip(x, -1, 1)
    x = (x + 1) / 2
    x = (255*x).astype(np.uint8)
    x = x.transpose(1,2,0)
    image = Image.fromarray(x)
    image.save(args.rec_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default='')
    parser.add_argument('--rec_path', type=str, default='')
    parser.add_argument('--ae', type=str, default='')
    parser.add_argument('--ae_path', type=str, default='')
    parser.add_argument('--model_path', type=str, default='results/pretrained')
    parser.add_argument('--short_size', type=int, default=336)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/rec_video.py
================================================
import math
import random
import argparse
from typing import Optional
import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from decord import VideoReader, cpu
from torch.nn import functional as F
from torchvision.transforms import Lambda, Compose
import sys
sys.path.append(".")
from opensora.models.causalvideovae import ae_wrapper
from opensora.dataset.transform import ToTensorVideo, CenterCropResizeVideo
def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
    height, width, channels = image_array[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
    for image in image_array:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        video_writer.write(image_rgb)
    video_writer.release()
def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(0, 2, 3, 1).float().numpy()
    x = (255 * x).astype(np.uint8)
    array_to_video(x, fps=fps, output_file=output_file)
    return
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    decord_vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(decord_vr)
    sample_frames_len = sample_rate * num_frames
    # if total_frames > sample_frames_len:
    #     s = random.randint(0, total_frames - sample_frames_len - 1)
    #     s = 0
    #     e = s + sample_frames_len
    #     num_frames = num_frames
    # else:
    #     s = 0
    #     e = total_frames
    #     num_frames = int(total_frames / sample_frames_len * num_frames)
    s = 0
    e = sample_frames_len
    print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
          total_frames)
    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data
def preprocess(video_data: torch.Tensor, height: int = 128, width: int = 128) -> torch.Tensor:
    transform = Compose(
        [
            ToTensorVideo(),
            CenterCropResizeVideo((height, width)),
            Lambda(lambda x: 2. * x - 1.)
        ]
    )
    video_outputs = transform(video_data)
    video_outputs = torch.unsqueeze(video_outputs, 0)
    return video_outputs
def main(args: argparse.Namespace):
    device = args.device
    kwarg = {}
    # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device)
    # vae = CausalVAEModelWrapper(args.ae_path, **kwarg).to(device)
    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor
        # vae.vae.tile_sample_min_size = 512
        # vae.vae.tile_latent_min_size = 64
        # vae.vae.tile_sample_min_size_t = 29
        # vae.vae.tile_latent_min_size_t = 8
        # if args.save_memory:
        #     vae.vae.tile_sample_min_size = 256
        #     vae.vae.tile_latent_min_size = 32
        #     vae.vae.tile_sample_min_size_t = 9
        #     vae.vae.tile_latent_min_size_t = 3
    dtype = torch.bfloat16
    vae.eval()
    vae = vae.to(device, dtype=dtype)
    with torch.no_grad():
        x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height,
                           args.width)
        print("input shape", x_vae.shape)
        x_vae = x_vae.to(device, dtype=dtype)  # b c t h w
        # for i in range(10000):
        latents = vae.encode(x_vae)
        latents = latents.to(dtype)
        video_recon = vae.decode(latents)  # b t c h w
        print("recon shape", video_recon.shape)
    # vae = vae.half()
    # from tqdm import tqdm
    # with torch.no_grad():
    #     x_vae = torch.rand(1, 3, 93, 720, 1280)
    #     print(x_vae.shape)
    #     x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w
    #     # x_vae = x_vae.to(device)  # b c t h w
    #     for i in tqdm(range(100000)):
    #         latents = vae.encode(x_vae)
    #         print(latents.shape)
    #         latents = latents.to(torch.float16)
    #         video_recon = vae.decode(latents)  # b t c h w
    #         print(video_recon.shape)
    custom_to_video(video_recon[0], fps=args.fps, output_file=args.rec_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', type=str, default='')
    parser.add_argument('--rec_path', type=str, default='')
    parser.add_argument('--ae', type=str, default='')
    parser.add_argument('--ae_path', type=str, default='')
    parser.add_argument('--model_path', type=str, default='results/pretrained')
    parser.add_argument('--fps', type=int, default=30)
    parser.add_argument('--height', type=int, default=336)
    parser.add_argument('--width', type=int, default=336)
    parser.add_argument('--num_frames', type=int, default=100)
    parser.add_argument('--sample_rate', type=int, default=1)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--tile_sample_min_size', type=int, default=512)
    parser.add_argument('--tile_sample_min_size_t', type=int, default=33)
    parser.add_argument('--tile_sample_min_size_dec', type=int, default=256)
    parser.add_argument('--tile_sample_min_size_dec_t', type=int, default=33)
    parser.add_argument('--enable_tiling', action='store_true')
    parser.add_argument('--save_memory', action='store_true')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/sora.txt
================================================
A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field
A movie trailer featuring the adventures of the 30 year old spaceman wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors.
Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.
Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image.
A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird's striking appearance.
Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
A young man at his 20s is sitting on a piece of cloud in the sky, reading a book.
A petri dish with a bamboo forest growing within it that has tiny red pandas running around.
The camera rotates around a large stack of vintage televisions all showing different programs-1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery.
3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.
Historical footage of California during the gold rush.
A close up view of a glass sphere that has a zen garden within it. There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand.
Extreme close up of a 24 year old woman's eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic.
A cartoon kangaroo disco dances.
A beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera.
A cat waking up its sleeping owner demanding breakfast. The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer.
Borneo wildlife on the Kinabatangan River
A Chinese Lunar New Year celebration video with Chinese Dragon.
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains with a clear blue sky above with wispy clouds.
Reflections in the window of a train traveling through the Tokyo suburbs.
A drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast Italy, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography
A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.
A flock of paper airplanes flutters through a dense jungle, weaving around trees as if they were migrating birds.
A beautiful silhouette animation shows a wolf howling at the moon, feeling lonely, until it finds its pack.
New York City submerged like Atlantis. Fish, whales, sea turtles and sharks swim through the streets of New York.
A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
Tour of an art gallery with many beautiful works of art in different styles.
Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes.
A stop motion animation of a flower growing out of the windowsill of a suburban house.
The story of a robot's life in a cyberpunk setting.
An extreme close-up of a gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret and glasses and has a very professorial appearance, and at the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.
Basketball through hoop then explodes
Archeologists discover a generic plastic chair in the desert, excavating and dusting it with great care
A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood
Step-printing scene of a person running, cinematic film shot in 35mm
Five gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing.
Tiltshift of a construction site filled with workers, equipment, and heavy machinery.
A giant, towering cloud in the shape of a man looms over the earth. The cloud man shoots lightning bolts down to the earth.
A Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glisten off of their fur.
The Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort William. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop for the train journey. The sky is blue and the sun is shining, making for a beautiful day to explore this majestic spot.
The camera directly faces colorful buildings in Burano Italy. An adorable dalmatian looks through a window on a building on the ground floor. Many people are walking and cycling along the canal streets in front of the buildings.
An adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands, 3D digital render art style.
This close-up shot of a chameleon showcases its striking color changing capabilities. The background is blurred, drawing attention to the animal's striking appearance.
A corgi vlogging itself in tropical Maui.
A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. The scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat's orange fur. The shot is clear and sharp, with a shallow depth of field.
Aerial view of Santorini during the blue hour, showcasing the stunning architecture of white Cycladic buildings with blue domes. The caldera views are breathtaking, and the lighting creates a beautiful, serene atmosphere.
Tiltshift of a construction site filled with workers, equipment, and heavy machinery.
================================================
FILE: opensora/__init__.py
================================================
#
================================================
FILE: opensora/acceleration/__init__.py
================================================
================================================
FILE: opensora/acceleration/communications.py
================================================
import torch
import torch.distributed as dist
from einops import rearrange
from opensora.acceleration.parallel_states import hccl_info, lccl_info, enable_LCCL
try:
    # lcalib is optional; fall back to HCCL collectives when it is unavailable.
    from lcalib.functional import lcal_all2allvc
except Exception:
    lcal_all2allvc = None


def broadcast(input_: torch.Tensor):
    # hccl_info.rank holds the global rank, so src is the first global rank
    # of the sequence-parallel group this process belongs to.
    sp_size = hccl_info.world_size
    src = hccl_info.rank // sp_size * sp_size
    dist.broadcast(input_, src=src, group=hccl_info.group)


_COUNT = 0


def _all_to_all(
    input_: torch.Tensor,
    scatter_dim: int,
    gather_dim: int,
):
    group = hccl_info.group
    sp_size = hccl_info.world_size
    # Split along scatter_dim, exchange one chunk with every rank, concatenate along gather_dim.
    input_list = [t.contiguous() for t in torch.tensor_split(input_, sp_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(sp_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


def _single_all_to_all(
    input_: torch.Tensor,
    scatter_dim: int,
    gather_dim: int,
    enable_HCCL=False,
):
    if enable_LCCL:
        sp_size = lccl_info.world_size
    else:
        sp_size = hccl_info.world_size
    inp_shape = list(input_.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size
    if scatter_dim < 1:
        input_t = input_.reshape(
            [sp_size, inp_shape[scatter_dim]] + \
            inp_shape[scatter_dim + 1:]
        )
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input_.reshape(
            [-1, sp_size, inp_shape[scatter_dim]] + \
            inp_shape[scatter_dim + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    if enable_LCCL and not enable_HCCL:
        # Every rank sends the same number of elements to every other rank.
        matrix_count = torch.ones([sp_size, sp_size], dtype=torch.int64, device=input_t.device) * (
            input_t.numel() // sp_size)
        lcal_all2allvc(input_t, output, matrix_count, lccl_info.group)
    else:
        dist.all_to_all_single(output, input_t, group=hccl_info.group)
    # if scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_dim < 1:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:])


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input tensor
        scatter_dim: scatter dimension
        gather_dim: gather dimension
        all_to_all_func: function that performs the exchange, called as
            all_to_all_func(input_, scatter_dim, gather_dim)
    """

    @staticmethod
    def forward(ctx, input_, scatter_dim, gather_dim, all_to_all_func):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.all_to_all = all_to_all_func
        output = ctx.all_to_all(input_, scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient flows through the inverse exchange: scatter and gather dims are swapped.
        grad_output = ctx.all_to_all(
            grad_output,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # One gradient per forward() argument; only input_ is differentiable.
        return (
            grad_output,
            None,
            None,
            None,
        )


def all_to_all_SBH(
    input_: torch.Tensor,
    scatter_dim: int = 1,
    gather_dim: int = 0,
):
    return _AllToAll.apply(input_, scatter_dim, gather_dim, _single_all_to_all)


def all_to_all_BSND(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return _AllToAll.apply(input_, scatter_dim, gather_dim, _all_to_all)


def prepare_parallel_data(
    hidden_states,
    encoder_hidden_states,
    attention_mask,
    encoder_attention_mask,
    pooled_projections=None,
):
    def all_to_all(
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        encoder_attention_mask,
        pooled_projections,
    ):
        # hidden_states (b c t h w) -gather0-> (sp*b c t h w) -scatter2-> (sp*b c t//sp h w)
        # encoder_hidden_states (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d)
        # attention_mask (b t*sp h w) -gather0-> (sp*b t*sp h w) -scatter1-> (sp*b t h w)
        # encoder_attention_mask (b sp l) -gather0-> (sp*b sp l) -scatter1-> (sp*b 1 l)
        # pooled_projections (b sp d) -gather0-> (sp*b sp d) -scatter1-> (sp*b 1 d)
        hidden_states = _single_all_to_all(hidden_states, scatter_dim=2, gather_dim=0, enable_HCCL=True)
        encoder_hidden_states = _single_all_to_all(encoder_hidden_states, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        attention_mask = _single_all_to_all(attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        encoder_attention_mask = _single_all_to_all(encoder_attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        if pooled_projections is not None:
            pooled_projections = _single_all_to_all(pooled_projections, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections

    sp_size = hccl_info.world_size
    frame = hidden_states.shape[2]
    assert frame % sp_size == 0, "frame should be a multiple of sp_size"

    encoder_hidden_states = rearrange(
        encoder_hidden_states, 'b 1 (n x) h -> b n x h',
        n=sp_size, x=encoder_hidden_states.shape[2]//sp_size
    ).contiguous()
    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections = all_to_all(
        hidden_states,
        encoder_hidden_states,
        attention_mask.repeat(1, sp_size, 1, 1),
        encoder_attention_mask.repeat(1, sp_size, 1),
        # pooled_projections is optional (defaults to None); only repeat it when provided.
        pooled_projections.repeat(1, sp_size, 1) if pooled_projections is not None else None,
    )

    return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections
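
The shape bookkeeping behind `_all_to_all` and the dimension swap in `_AllToAll.backward` can be checked without any accelerator. The following is a minimal single-process sketch, not the project's code: `split`, `concat` and `simulate_all_to_all` are hypothetical helpers standing in for `torch.tensor_split`, `torch.cat` and `dist.all_to_all`, and the "tensors" are 2-D nested lists with dimensions (sequence, heads) rather than the (B, S, N, D) layout used above.

```python
from typing import List

Matrix = List[List[int]]  # 2-D "tensor": dim 0 = sequence, dim 1 = heads


def split(m: Matrix, parts: int, dim: int) -> List[Matrix]:
    """Split a matrix into `parts` equal chunks along `dim` (0 = rows, 1 = columns)."""
    if dim == 0:
        step = len(m) // parts
        return [m[i * step:(i + 1) * step] for i in range(parts)]
    step = len(m[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in m] for i in range(parts)]


def concat(chunks: List[Matrix], dim: int) -> Matrix:
    """Concatenate matrices along `dim` (0 = rows, 1 = columns)."""
    if dim == 0:
        return [row for chunk in chunks for row in chunk]
    return [sum((chunk[r] for chunk in chunks), []) for r in range(len(chunks[0]))]


def simulate_all_to_all(per_rank: List[Matrix], scatter_dim: int, gather_dim: int) -> List[Matrix]:
    """Hypothetical stand-in for dist.all_to_all across len(per_rank) ranks."""
    sp_size = len(per_rank)
    # Each rank splits its local data into sp_size chunks along scatter_dim.
    sent = [split(local, sp_size, scatter_dim) for local in per_rank]
    # Rank dst receives chunk dst from every rank and concatenates along gather_dim.
    return [concat([sent[src][dst] for src in range(sp_size)], gather_dim)
            for dst in range(sp_size)]


sp_size, seq_len, heads = 2, 4, 4
# Entry value encodes (sequence index, head index) as 10 * s + h.
full = [[10 * s + h for h in range(heads)] for s in range(seq_len)]
# Sequence parallelism: each rank starts with a slice of the sequence and all heads.
shards = split(full, sp_size, dim=0)

# Scatter heads (dim 1), gather sequence (dim 0): full sequence, subset of heads.
by_head = simulate_all_to_all(shards, scatter_dim=1, gather_dim=0)
# Swapping the two dims, as _AllToAll.backward does, undoes the exchange.
restored = simulate_all_to_all(by_head, scatter_dim=0, gather_dim=1)
```

After the first exchange each simulated rank holds the whole sequence for half of the heads; the second exchange with swapped dimensions returns the original sequence shards, which is why the backward pass can reuse the same function.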
================================================
FILE: opensora/acceleration/parallel_states.py
================================================
import torch
import torch_npu
import torch.distributed as dist
import os
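try:
    # lcalib is optional; LCCL is only enabled when it can be imported.
    from lcalib.functional import lcal_initialize
    enable_LCCL = True
except Exception:
    lcal_initialize = None
    enable_LCCL = False


class COMM_INFO:
    """Communication group, group size and rank for one backend."""

    def __init__(self):
        self.group = None
        self.world_size = 0
        self.rank = -1


lccl_info = COMM_INFO()
hccl_info = COMM_INFO()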
_SEQUENCE_PARALLEL_STATE = False


def initialize_sequence_parallel_state(sequence_parallel_size):
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        _SEQUENCE_PARALLEL_STATE = True
        initialize_sequence_parallel_group(sequence_parallel_size)


def set_sequence_parallel_state(state):
    global _SEQUENCE_PARALLEL_STATE
    _SEQUENCE_PARALLEL_STATE = state


def get_sequence_parallel_state():
    return _SEQUENCE_PARALLEL_STATE


def initialize_sequence_parallel_group(sequence_parallel_size):
    """Initialize the sequence parallel group."""
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))
    assert world_size % sequence_parallel_size == 0, "world_size must be divisible by sequence_parallel_size"
    # hccl
    hccl_info.world_size = sequence_parallel_size
    hccl_info.rank = rank
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    # Every process must create every group, but keeps only the one it belongs to.
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            hccl_info.group = group

    if enable_LCCL:
        assert sequence_parallel_size == 8, "sequence_parallel_size should be 8 when enable_LCCL is True"
        # LCCL uses the rank within the group rather than the global rank.
        rank %= sequence_parallel_size
        lccl_info.world_size = sequence_parallel_size
        lccl_info.group = lcal_initialize(rank, sequence_parallel_size)
        lccl_info.rank = rank


def destroy_sequence_parallel_group():
    """Destroy the sequence parallel group."""
    dist.destroy_process_group()
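
To make the grouping concrete, here is a small sketch of how global ranks map onto sequence-parallel groups, and which rank `broadcast()` in `communications.py` would pick as its source. The helper names `sequence_parallel_groups` and `broadcast_source` are hypothetical and exist only for this illustration; they reproduce the arithmetic, not the process-group creation.

```python
from typing import List


def sequence_parallel_groups(world_size: int, sp_size: int) -> List[List[int]]:
    """List the global ranks of each sequence-parallel group (contiguous blocks)."""
    if world_size % sp_size != 0:
        raise ValueError("world_size must be divisible by sequence_parallel_size")
    return [list(range(i * sp_size, (i + 1) * sp_size))
            for i in range(world_size // sp_size)]


def broadcast_source(rank: int, sp_size: int) -> int:
    """First global rank of the group containing `rank`."""
    return rank // sp_size * sp_size


groups = sequence_parallel_groups(world_size=8, sp_size=4)
sources = [broadcast_source(r, 4) for r in range(8)]
print(groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(sources)  # [0, 0, 0, 0, 4, 4, 4, 4]
```

With 8 processes and a sequence-parallel size of 4, ranks 0 to 3 and ranks 4 to 7 each form one group, and each group broadcasts from its lowest rank.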
================================================
FILE: opensora/adaptor/__init__.py
================================================
================================================
FILE: opensora/adaptor/bf16_optimizer.py
================================================
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from collections import OrderedDict
import torch
import sys
import os
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)
from deepspeed.utils import link_hp_params, fragment_address
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
PARAM_SLICE_MAPPINGS)
setattr(sys.modules[__name__], 'fragment_address', fragment_address)
def contigous_flatten(tensors):
return _flatten_dense_tensors([tensor.contiguous() for tensor in tensors])
class BF16_Optimizer(ZeROOptimizer):
def __init__(self,
init_optimizer,
param_names,
mpu=None,
clip_grad=0.0,
norm_type=2,
allgather_bucket_size=5000000000,
dp_process_group=None,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False):
# super().__init__()
# base_class = ZeROOptimizer.__bases__[0]
# # Call the base class's __init__ method directly
# base_class.__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
self.optimizer = init_optimizer
self.param_names = param_names
self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)
assert grad_acc_dtype in [torch.float32, torch.bfloat16
], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
self.grad_acc_dtype = grad_acc_dtype
self.clip_grad = clip_grad
self.norm_type = norm_type
self.mpu = mpu
self.allgather_bucket_size = int(allgather_bucket_size)
self.dp_process_group = dp_process_group
self.dp_rank = dist.get_rank(group=self.dp_process_group)
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
# Use torch (un)flatten ops
self.flatten = contigous_flatten
self.unflatten = _unflatten_dense_tensors
# align nccl all-gather send buffers to 4-byte boundary
self.nccl_start_alignment_factor = 16
# Build BF16/FP32 groups
self.bf16_groups = []
self.bf16_groups_flat = []
self.bf16_partitioned_groups = []
self.fp32_groups_flat_partition = []
# Maintain different fp32 gradients views for convenience
self.fp32_groups_gradients = []
self.fp32_groups_gradient_dict = {}
self.fp32_groups_gradients_flat = []
self.fp32_groups_actual_gradients_flat = []
self.fp32_groups_gradient_flat_partition = []
self.fp32_groups_has_gradients = []
self.group_paddings = []
self.graph_harvesting = graph_harvesting
if self.using_real_optimizer:
self._setup_for_real_optimizer()
see_memory_usage('end bf16_optimizer', force=True)
def _setup_for_real_optimizer(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
for i, param_group in enumerate(self.optimizer.param_groups):
see_memory_usage(f'before initializing group {i}', force=True)
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
# grab the original list
trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
self.bf16_groups.append(trainable_parameters)
# create flat bf16 params
self.bf16_groups_flat.append(
self._flatten_dense_tensors_aligned(self.bf16_groups[i],
self.nccl_start_alignment_factor * dp_world_size))
# Make bf16 params point to flat tensor storage
self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
flat_tensor=self.bf16_groups_flat[i])
# divide flat weights into equal sized partitions
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
bf16_dp_partitions = [
self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
for dp_index in range(dp_world_size)
]
self.bf16_partitioned_groups.append(bf16_dp_partitions)
# create fp32 params partition
self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
self.fp32_groups_flat_partition[i].requires_grad = True
num_elem_list = [t.numel() for t in self.bf16_groups[i]]
# create fp32 gradients
self.fp32_groups_gradients_flat.append(
torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))
# track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
num_elem_list=num_elem_list)
self.fp32_groups_gradients.append(fp32_gradients)
self.fp32_groups_gradient_dict[i] = fp32_gradients
# flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
length_without_padding = sum(num_elem_list)
self.fp32_groups_actual_gradients_flat.append(
torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))
# flat tensor corresponding to gradient partition
self.fp32_groups_gradient_flat_partition.append(
torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))
# track fp32 gradient updates
self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))
# Record padding required for alignment
if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
padding = self.bf16_groups_flat[i].numel() - length_without_padding
else:
padding = 0
self.group_paddings.append(padding)
# update optimizer param groups to reference fp32 params partition
param_group['params'] = [self.fp32_groups_flat_partition[i]]
see_memory_usage(f'after initializing group {i}', force=True)
see_memory_usage('before initialize_optimizer', force=True)
self.initialize_optimizer_states()
see_memory_usage('end initialize_optimizer', force=True)
# Need optimizer states initialized before linking lp to optimizer state
self._link_all_hp_params()
self._enable_universal_checkpoint()
self._param_slice_mappings = self._create_param_mapping()
def _enable_universal_checkpoint(self):
for lp_param_group in self.bf16_groups:
enable_universal_checkpoint(param_list=lp_param_group)
def _create_param_mapping(self):
param_mapping = []
for i, _ in enumerate(self.optimizer.param_groups):
param_mapping_per_group = OrderedDict()
for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
lp_name = self.param_names[lp]
param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
param_mapping.append(param_mapping_per_group)
return param_mapping
def _link_all_hp_params(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.optimizer.param_groups):
# Link bf16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
gradient_dict=self.fp32_groups_gradient_dict,
offload_gradient_dict=None,
use_offload=False,
param_group_index=i,
partition_start=partition_id * partition_size,
partition_size=partition_size,
partition_optimizer_state=self.optimizer.state[flat_hp_partition],
dp_group=self.real_dp_process_group[i])
def initialize_optimizer_states(self):
"""Take an optimizer step with zero-valued gradients to allocate internal
optimizer state.
This helps prevent memory fragmentation by allocating optimizer state at the
beginning of training instead of after activations have been allocated.
"""
for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
self.fp32_groups_gradient_flat_partition):
# In case of grad acc dtype different than FP32, need to cast to high precision.
param_partition.grad = grad_partition.to(
param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
self.optimizer.step()
if self.grad_acc_dtype is not torch.float32:
for param_partition in self.fp32_groups_flat_partition:
param_partition.grad = None
self.clear_hp_grads()
def _split_flat_tensor(self, flat_tensor, num_elem_list):
assert sum(num_elem_list) <= flat_tensor.numel()
tensor_list = []
offset = 0
for num_elem in num_elem_list:
dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
tensor_list.append(dense_tensor)
offset += num_elem
return tensor_list
def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
updated_params = self.unflatten(flat_tensor, tensor_list)
for p, q in zip(tensor_list, updated_params):
p.data = q.data
def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
return self.flatten(align_dense_tensors(tensor_list, alignment))
@torch.no_grad()
def step(self, closure=None):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')
all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
self._global_grad_norm = all_groups_norm
assert all_groups_norm > 0.
if self.clip_grad > 0.:
clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
max_norm=self.clip_grad,
global_norm=all_groups_norm,
mpu=self.mpu,
use_graph=self.graph_harvesting)
self.optimizer.step()
self.update_lp_params()
self.clear_hp_grads()
def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
"""Perform a backward pass and copy the low-precision gradients to the
high-precision copy.
We copy/accumulate to the high-precision grads now to prevent accumulating in the
bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1)
The low-precision grads are deallocated during this procedure.
"""
self.clear_lp_grads()
loss.backward(**bwd_kwargs)
if update_hp_grads:
self.update_hp_grads(clear_lp_grads=clear_lp_grads)
@torch.no_grad()
def update_hp_grads(self, clear_lp_grads=False):
def _update_hp_grads_func(clear_lp_grads=False):
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if lp.grad is None:
continue
hp_grad = self.fp32_groups_gradients[i][j]
assert hp_grad is not None, \
f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
lp._hp_grad = hp_grad
self.fp32_groups_has_gradients[i][j] = True
# clear gradients
if clear_lp_grads:
lp.grad.zero_()
if self.graph_harvesting:
graph_process(False, _update_hp_grads_func, clear_lp_grads)
else:
_update_hp_grads_func(clear_lp_grads)
#cpu op
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if lp.grad is None:
continue
self.fp32_groups_has_gradients[i][j] = True
@torch.no_grad()
def get_grads_for_reduction(self):
return self.fp32_groups_gradients_flat
@torch.no_grad()
def get_grads_for_norm(self, for_clipping=False):
grads = []
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if not for_clipping:
if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
continue
if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
continue
if not self.fp32_groups_has_gradients[i][j]:
continue
grads.append(self.fp32_groups_gradients[i][j])
return grads
@torch.no_grad()
def update_lp_params(self):
for i, (bf16_partitions,
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
# if i == 0:
# print_rank_0(f'{fp32_partition[:10]=}', force=True)
all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
partitioned_param_groups=self.bf16_partitioned_groups,
dp_process_group=self.real_dp_process_group,
start_alignment_factor=self.nccl_start_alignment_factor,
allgather_bucket_size=self.allgather_bucket_size)
def clear_hp_grads(self):
for flat_gradients in self.fp32_groups_gradients_flat:
flat_gradients.zero_()
for i, group in enumerate(self.fp32_groups_gradients):
self.fp32_groups_has_gradients[i] = [False] * len(group)
def clear_lp_grads(self):
for group in self.bf16_groups:
for param in group:
if param.grad is not None:
# Using zero_() fixed memory address for graph replay
param.grad.zero_()
def state_dict(self):
state_dict = {}
state_dict[CLIP_GRAD] = self.clip_grad
state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
state_dict[GROUP_PADDINGS] = self.group_paddings
state_dict[PARTITION_COUNT] = self.partition_count
state_dict[DS_VERSION] = version
state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings
return state_dict
# Restore base optimizer fp32 weights from bfloat16 weights
def _restore_from_bit16_weights(self):
for i, group in enumerate(self.bf16_groups):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition):
fp32_partition.data.copy_(bf16_partitions[partition_id].data)
def refresh_fp32_params(self):
self._restore_from_bit16_weights()
def load_state_dict(self,
state_dict_list,
checkpoint_folder,
load_optimizer_states=True,
load_from_fp32_weights=False,
load_serial=None):
if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
dp_rank = dist.get_rank(group=self.dp_process_group)
current_rank_sd = state_dict_list[dp_rank]
ckpt_version = current_rank_sd.get(DS_VERSION, False)
assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
ckpt_version = pkg_version.parse(ckpt_version)
self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)
if load_optimizer_states:
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
if load_from_fp32_weights:
for current, saved in zip(self.fp32_groups_flat_partition,
current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
src_tensor = _get_padded_tensor(saved, current.numel())
current.data.copy_(src_tensor.data)
if load_optimizer_states:
self._link_all_hp_params()
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self._load_hp_checkpoint_state(checkpoint_folder)
@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups
def _load_hp_checkpoint_state(self, checkpoint_dir):
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size()
for i, _ in enumerate(self.optimizer.param_groups):
for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
def _get_padded_tensor(src_tensor, size):
if src_tensor.numel() >= size:
return src_tensor
padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
slice_tensor.data.copy_(src_tensor.data)
return padded_tensor
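
The partition arithmetic in `_setup_for_real_optimizer` can be followed with plain integers. This is a minimal sketch under one stated assumption: that `align_dense_tensors` zero-pads the flat buffer up to the next multiple of the alignment (`nccl_start_alignment_factor * dp_world_size`), which is not shown in this file. The function `partition_layout` is hypothetical and only mirrors the bookkeeping for `partition_size` and `group_paddings`.

```python
from typing import Dict, List


def partition_layout(num_elem_list: List[int], dp_world_size: int,
                     alignment_factor: int = 16) -> Dict[str, object]:
    """Compute the flat-buffer layout for one parameter group.

    Assumption: the flat buffer is padded up to the next multiple of
    alignment_factor * dp_world_size so that it splits evenly across ranks.
    """
    alignment = alignment_factor * dp_world_size
    length_without_padding = sum(num_elem_list)
    # Round up to a multiple of `alignment` using integer ceiling division.
    padded_length = -(-length_without_padding // alignment) * alignment
    partition_size = padded_length // dp_world_size
    # Per-rank [start, end) ranges in the flat buffer.
    ranges = [(r * partition_size, (r + 1) * partition_size) for r in range(dp_world_size)]
    # As in the optimizer above, only the last rank records the padding.
    paddings = [padded_length - length_without_padding if r == dp_world_size - 1 else 0
                for r in range(dp_world_size)]
    return {
        "padded_length": padded_length,
        "partition_size": partition_size,
        "ranges": ranges,
        "paddings": paddings,
    }


layout = partition_layout(num_elem_list=[100, 50, 30], dp_world_size=4)
print(layout["padded_length"], layout["partition_size"])  # 192 48
print(layout["paddings"])                                  # [0, 0, 0, 12]
```

Three parameters with 180 elements in total are padded to 192 elements for 4 data-parallel ranks, so each rank owns a 48-element slice and the last rank accounts for the 12 padding elements.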
================================================
FILE: opensora/adaptor/engine.py
================================================
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import re
import stat
import torch
import hashlib
from collections import defaultdict, OrderedDict, deque
from shutil import copyfile
import gc
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from typing import Callable, Dict, Union, Iterable
import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
MUSGD_OPTIMIZER, LION_OPTIMIZER
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
DATA_PARALLEL_GROUP, GLOBAL_RANK
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.compression import compression_scheduler
from deepspeed.compression.constants import \
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \
WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \
WEIGHT_QUANTIZE_ENABLED, \
WEIGHT_QUANTIZE_GROUPS, \
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \
WEIGHT_QUANTIZE_CHANGE_RATIO, \
WEIGHT_QUANTIZE_TYPE, \
WEIGHT_QUANTIZE_ROUNDING, \
WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
from deepspeed.utils import groups
from deepspeed.utils import logger, log_dist, instrument_w_nvtx
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
STEP_MICRO_TIMER, \
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
STEP_GLOBAL_TIMER
from deepspeed.utils.debug import debug_extract_module_and_param_names
from deepspeed.monitor.monitor import MonitorMaster
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.utils import clip_grad_norm_
from deepspeed.runtime.eigenvalue import Eigenvalue
from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \
CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \
RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \
RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \
RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from deepspeed.runtime.pipe.module import PipelineModule
from deepspeed.runtime.utils import get_ma_status
from deepspeed.ops.adam import FusedAdam
from deepspeed.moe.sharded_moe import TopKGate, MOELayer
from deepspeed.moe.layer import MoE
from deepspeed.moe.utils import is_moe_param
from deepspeed.git_version_info import version
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
from deepspeed.utils.logging import print_json_dist, print_configuration
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.config import DtypeEnum
from opensora.adaptor.zp_manager import zp_manager
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
DeepSpeedOptimizerCallable = \
Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer]
DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler]
try:
import apex
from apex import amp
APEX_INSTALLED = True
except ImportError:
# Fail silently so we don't spam logs unnecessarily if user isn't using amp
APEX_INSTALLED = False
def split_half_float_double_sparse(tensors):
device_type = get_accelerator().device_name()
supported_types = [
"torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
"torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type),
SparseTensor.type()
]
for t in tensors:
assert t.type() in supported_types, f"attempting to reduce an unsupported grad type: {t.type()}"
buckets = []
for i, dtype in enumerate(supported_types):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append((dtype, bucket))
return buckets
class EngineTimers(object):
r"""Wallclock timers for DeepSpeedEngine"""
def __init__(self, enable_micro_timers, enable_global_timers):
self.forward_timers = []
self.backward_timers = []
self.backward_inner_timers = []
self.backward_reduce_timers = []
self.step_timers = []
self.global_timers = []
self.micro_timers = []
if enable_micro_timers:
self.forward_timers += [FORWARD_MICRO_TIMER]
self.backward_timers += [BACKWARD_MICRO_TIMER]
self.backward_inner_timers += [BACKWARD_INNER_MICRO_TIMER]
self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER]
self.step_timers += [STEP_MICRO_TIMER]
self.micro_timers += [
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER,
STEP_MICRO_TIMER
]
if enable_global_timers:
self.forward_timers += [FORWARD_GLOBAL_TIMER]
self.backward_timers += [BACKWARD_GLOBAL_TIMER]
self.backward_inner_timers += [BACKWARD_INNER_GLOBAL_TIMER]
self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER]
self.step_timers += [STEP_GLOBAL_TIMER]
self.global_timers += [
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER,
STEP_GLOBAL_TIMER
]
class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training."""
def __init__(
self,
args,
model,
optimizer=None,
model_parameters=None,
training_data=None,
lr_scheduler=None,
mpu=None,
dist_init_required=None,
collate_fn=None,
config=None,
config_class=None,
dont_change_device=False,
):
super(DeepSpeedEngine, self).__init__()
self.dont_change_device = dont_change_device
self.client_optimizer = optimizer
self.client_lr_scheduler = lr_scheduler
self.training_data = training_data
self.collate_fn = collate_fn
self.mpu = mpu
self.all_to_all_group = None
self.data_parallel_group = None
self.global_steps = 0
self.global_samples = 0
self.micro_steps = 0
self.skipped_steps = 0
self.gradient_average = True
self.warn_unscaled_loss = True
self.config = config
self._config = config_class
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.loaded_checkpoint_zp_world_size = None
self.enable_backward_allreduce = True
self.progressive_layer_drop = None
self.eigenvalue = None
self.block_eigenvalue = None
self.gas_boundary_ctr = 0
self.dist_backend = get_accelerator().communication_backend_name()
self.has_moe_layers = False
self.num_experts = []
self.gate_modules = []
self.moe_layers = []
self._step_applied = False
self._global_grad_norm = None
self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend.
self.checkpoint_engine = None
self._is_gradient_accumulation_boundary = None
self.scale_wrt_gas = None
self.losses = 0.0
# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
see_memory_usage("DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
if mpu is not None:
if self.elasticity_enabled():
if not self.is_elastic_model_parallel_supported():
assert not self.elasticity_enabled(), ("Elasticity is not currently supported"
" with model parallelism.")
self._set_distributed_vars(args)
dist.configure(self._config)
self.monitor = MonitorMaster(self._config.monitor_config)
see_memory_usage(
"DeepSpeed Engine: Before configure distributed model",
force=self.memory_breakdown(),
)
self.pipeline_parallelism = isinstance(model, PipelineModule)
# Configure distributed model
self._configure_distributed_model(model)
# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}
self._get_model_parameters()
see_memory_usage("DeepSpeed Engine: After configure distributed model")
# Configure wall clock timers
self.timers = SynchronizedWallClockTimer()
# Throughput timer
self.tput_timer = ThroughputTimer(
batch_size=self.train_batch_size(),
steps_per_output=self.steps_per_print(),
monitor_memory=False,
)
log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0])
if self.flops_profiler_enabled():
self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor())
if training_data:
self.training_dataloader = self.deepspeed_io(training_data)
else:
self.training_dataloader = None
# Configure optimizer and scheduler
self.optimizer = None
self.basic_optimizer = None
self.lr_scheduler = None
has_optimizer = False
if optimizer or self.optimizer_name():
has_optimizer = True
# If no parameters given by init default to module parameters
if model_parameters is None:
model_parameters = self.module.parameters()
# Convert model parameters from generator to list
if not isinstance(model_parameters, list):
model_parameters = list(model_parameters)
if has_optimizer:
self._configure_optimizer(optimizer, model_parameters)
self._configure_lr_scheduler(lr_scheduler)
self._report_progress(0)
elif self.zero_optimization():
# no optim selected but zero is enabled
self.optimizer = self._configure_zero_optimizer(optimizer=None)
elif self.bfloat16_enabled():
self.optimizer = self._configure_bf16_optimizer(optimizer=None)
# Hook optimizer for snip_momentum pruning
if hasattr(model, 'pruners'):
from deepspeed.compression.helper import rewrite_optimizer_step
self.optimizer.pruners = model.pruners
rewrite_optimizer_step(self.optimizer)
# Bookkeeping for sparse support
self.sparse_tensor_module_names = set()
# if self.sparse_gradients_enabled():
for name, module in self.module.named_modules():
if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():
self.sparse_tensor_module_names.add(name + ".weight")
logger.info("Will convert {} to sparse tensor during training".format(name))
self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False
if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
self._configure_checkpointing(dist_init_required)
if self.eigenvalue_enabled():
self.eigenvalue = self._configure_eigenvalue()
if self.pld_enabled():
self.progressive_layer_drop = self._configure_progressive_layer_drop()
if self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy()
if self.random_ltd_enabled():
random_ltd_config = self.random_ltd_config()
random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size()
random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu()
self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config)
# Engine timers
self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(),
enable_global_timers=self.wall_clock_breakdown()
or self.flops_profiler_enabled())
if self.global_rank == 0:
self._config.print("DeepSpeedEngine configuration")
if self.dump_state():
print_configuration(self, "DeepSpeedEngine")
# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
self.unflatten = _unflatten_dense_tensors
def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
self.optimizer.destroy()
def _get_model_parameters(self):
if self.autotuning_profile_model_info():
self.autotuning_model_info = {}
num_params = 0
trainable_num_params = 0
for p in self.module.parameters():
# User code may call deepspeed.zero.Init() before deepspeed.initialize(), so check whether the parameter is already partitioned by ZeRO stage 3.
n = 0
if hasattr(p, "ds_tensor"): # if the parameter is partitioned in zero 3
n += p.ds_numel
else: # if the parameter is not partitioned in zero 3 yet
n += p.numel()
num_params += n
if p.requires_grad:
trainable_num_params += n
if self.global_rank == 0:
self.autotuning_model_info["num_params"] = num_params * self.mp_world_size
self.autotuning_model_info["trainable_num_params"] = trainable_num_params * self.mp_world_size
logger.info(f"model parameter = {num_params}")
def get_batch_info(self):
"""Get all training batch related settings.
Returns:
train_batch_size (int): The effective training batch size. This is the amount of data
samples that leads to one step of model update.
train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one
step (without gradient accumulation).
gradient_accumulation_steps (int): Number of training steps to accumulate gradients
before averaging and applying them.
"""
return (
self.train_batch_size(),
self.train_micro_batch_size_per_gpu(),
self.gradient_accumulation_steps(),
)
def set_train_batch_size(self, train_batch_size):
"""Adjust the global batch size by increasing or decreasing the number of
micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
(i.e., ``train_micro_batch_size_per_gpu``) is not changed.
Args:
train_batch_size (int): The new global batch size for training.
Raises:
ValueError: if ``train_batch_size`` is not divisible by the
configured micro-batch size and data parallelism.
"""
if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0:
#print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
raise ValueError('Train batch size must be divisible by the product of micro-batch size and data parallel world size')
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size)
# overwrite config
self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas
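# Worked example for the check above (illustrative numbers, not taken from any config):
# with train_micro_batch_size_per_gpu() == 4 and dp_world_size == 8, one micro-step
# consumes 4 * 8 = 32 samples. set_train_batch_size(96) passes because 96 % 32 == 0 and
# yields new_gas = 96 // 32 = 3; set_train_batch_size(100) raises ValueError because
# 100 % 32 == 4.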
def set_train_micro_batch_size(self, micro_batch_size):
"""Adjust the micro batch size(i.e., the micro batch size in every data parallel group),
while keeping the gradient accumulation steps the same.
Args:
micro_batch_size (int): The new micro batch size for training.
"""
# overwrite config
new_global_batch_size = micro_batch_size * self._config.gradient_accumulation_steps * self.dp_world_size
self._config.train_batch_size = new_global_batch_size
self._config.train_micro_batch_size_per_gpu = micro_batch_size
def set_data_post_process_func(self, post_process_func):
if self.training_dataloader is not None:
self.training_dataloader.post_process_func = post_process_func
def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
if self.training_dataloader is not None and self.curriculum_learning_enabled():
self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)
def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
the norm will be global.
The computed norm will be cached and reused until the next step() pass.
.. note::
In the presence of model parallelism, this is a collective call
and acts as a barrier among ``mpu.get_model_parallel_group()``.
Returns:
float: norm
"""
return self._global_grad_norm
def __getattr__(self, name):
"""
Pass through attributes defined in the model if they are not overridden by ds-engine.
"""
_module = {}
if "module" in self.__dict__:
_module = self.__dict__['module']
if name in dir(self):
return getattr(self, name)
elif name in dir(_module):
return getattr(_module, name)
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled
def checkpoint_tag_validation_fail(self):
return self._config.checkpoint_tag_validation_fail
def elasticity_enabled(self):
return self._config.elasticity_enabled
def is_elastic_model_parallel_supported(self):
if self.elasticity_enabled():
# Add code for finding number of GPUs per node automatically
if self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0:
return True
else:
return False
def pld_enabled(self):
return self._config.pld_enabled
def pld_params(self):
return self._config.pld_params
def pld_theta(self):
return self.pld_params()[PLD_THETA]
def pld_gamma(self):
return self.pld_params()[PLD_GAMMA]
def eigenvalue_enabled(self):
return self._config.eigenvalue_enabled
def eigenvalue_verbose(self):
return self._config.eigenvalue_verbose
def eigenvalue_max_iter(self):
return self._config.eigenvalue_max_iter
def eigenvalue_tol(self):
return self._config.eigenvalue_tol
def eigenvalue_stability(self):
return self._config.eigenvalue_stability
def eigenvalue_gas_boundary_resolution(self):
return self._config.eigenvalue_gas_boundary_resolution
def eigenvalue_layer_name(self):
return self._config.eigenvalue_layer_name
def eigenvalue_layer_num(self):
return self._config.eigenvalue_layer_num
def curriculum_enabled_legacy(self):
return self._config.curriculum_enabled_legacy
def curriculum_params_legacy(self):
return self._config.curriculum_params_legacy
def data_efficiency_enabled(self):
return self._config.data_efficiency_enabled
def data_efficiency_config(self):
return self._config.data_efficiency_config
def data_sampling_enabled(self):
return self._config.data_efficiency_config[DATA_SAMPLING][DATA_SAMPLING_ENABLED]
def data_sampling_config(self):
return self._config.data_efficiency_config[DATA_SAMPLING]
def curriculum_learning_enabled(self):
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]
def curriculum_learning_config(self):
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
def random_ltd_enabled(self):
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED]
def random_ltd_config(self):
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD]
def random_ltd_initialize(self):
assert self.random_ltd_enabled()
random_ltd_config = self.random_ltd_config()
random_ltd_queue = deque([x for x in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])])
count = 0
for name, layer in self.module.named_modules():
if isinstance(layer, RandomLayerTokenDrop):
if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name: ###[1,2,3]
layer.init_config(random_ltd_config, self.random_ltd_scheduler, count)
random_ltd_queue.popleft()
count += 1
if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count:
raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must '
f'equal the length of random_ltd_layer_id ({count})')
if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
assert self.client_lr_scheduler is None
raise ValueError('not yet supported')
#self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)
def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown
def flops_profiler_enabled(self):
return self._config.flops_profiler_config.enabled or self.autotuning_enabled()
def flops_profiler_recompute_fwd_factor(self):
return self._config.flops_profiler_config.recompute_fwd_factor
def flops_profiler_profile_step(self):
step = self._config.flops_profiler_config.profile_step
if self._config.autotuning_config.enabled:
step = self.autotuning_start_profile_step()
return step
def flops_profiler_module_depth(self):
return self._config.flops_profiler_config.module_depth
def flops_profiler_top_modules(self):
return self._config.flops_profiler_config.top_modules
def flops_profiler_detailed(self):
if self._config.autotuning_config.enabled:
return False
return self._config.flops_profiler_config.detailed
def flops_profiler_output_file(self):
return self._config.flops_profiler_config.output_file
def memory_breakdown(self):
return self._config.memory_breakdown
def autotuning_enabled(self):
return self._config.autotuning_config.enabled
def autotuning_start_profile_step(self):
return self._config.autotuning_config.start_profile_step
def autotuning_end_profile_step(self):
return self._config.autotuning_config.end_profile_step
def autotuning_metric_path(self):
path = self._config.autotuning_config.metric_path
if not path:
path = os.path.join(os.getcwd(), "autotuning_metric.json")
return path
def autotuning_model_info_path(self):
path = self._config.autotuning_config.model_info_path
if not path:
path = os.path.join(os.getcwd(), "autotuning_model_info.json")
return path
def autotuning_metric(self):
return self._config.autotuning_config.metric
def autotuning_profile_model_info(self):
return self.autotuning_enabled(
) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get(
"profile", False)
def sparse_gradients_enabled(self):
return self._config.sparse_gradients_enabled
def train_batch_size(self):
return self._config.train_batch_size
def train_micro_batch_size_per_gpu(self):
return self._config.train_micro_batch_size_per_gpu
def optimizer_name(self):
return (self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name)
def optimizer_params(self):
return self._config.optimizer_params
def optimizer_legacy_fusion(self):
return self._config.optimizer_legacy_fusion
def scheduler_name(self):
return self._config.scheduler_name
def scheduler_params(self):
return self._config.scheduler_params
def quantize_training(self):
return (
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_GROUPS],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_CHANGE_RATIO],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_TYPE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ROUNDING],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_VERBOSE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_KERNEL],
)
def zero_optimization(self):
return self._config.zero_enabled
def zero_allow_untested_optimizer(self):
return self._config.zero_allow_untested_optimizer
def zero_force_ds_cpu_optimizer(self):
return self._config.zero_force_ds_cpu_optimizer
def zero_reduce_scatter(self):
return self._config.zero_config.reduce_scatter
def zero_overlap_comm(self):
return self._config.zero_config.overlap_comm
def zero_offload_optimizer(self):
return self._config.zero_config.offload_optimizer
def zero_offload_param(self):
return self._config.zero_config.offload_param
def zero_use_cpu_optimizer(self):
if self._config.zero_config.offload_optimizer is not None:
return self._config.zero_config.offload_optimizer.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]
return False
def zero_cpu_offload(self):
if self._config.zero_config.offload_optimizer is not None:
return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu
return False
def zero_partial_offload(self):
return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)
def zero_sub_group_size(self):
return self._config.zero_config.sub_group_size
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
def mics_shard_size(self):
return self._config.mics_shard_size
def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size
def zero_multi_rank_bucket_allreduce(self):
return self._config.zero_config.use_multi_rank_bucket_allreduce
def zero_allgather_bucket_size(self):
return self._config.zero_config.allgather_bucket_size
def zero_optimization_partition_gradients(self):
return self.zero_optimization_stage() >= ZeroStageEnum.gradients
def zero_optimization_partition_weights(self):
return self.zero_optimization_stage() >= ZeroStageEnum.weights
def is_first_weights_partition_group(self):
ret = True if self.mics_shard_size() < 0 \
and self.zero_optimization_partition_weights() else False
if self.mics_shard_size() > 0 and self.global_rank < self.mics_shard_size():
ret = True
return ret
def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients
def zero_load_from_fp32_weights(self):
return self._config.zero_config.load_from_fp32_weights
def zero_elastic_checkpoint(self):
return self._config.zero_config.elastic_checkpoint
def zero_max_live_parameters(self):
return self._config.zero_config.max_live_parameters
def zero_max_reuse_distance(self):
return self._config.zero_config.max_reuse_distance
def zero_prefetch_bucket_size(self):
return self._config.zero_config.prefetch_bucket_size
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold
def zero_model_persistence_threshold(self):
return self._config.zero_config.model_persistence_threshold
def zero_gather_16bit_weights_on_model_save(self):
return self._config.zero_config.gather_16bit_weights_on_model_save
def zero_grad_hooks(self):
return self._config.zero_config.grad_hooks
def zero_legacy_stage1(self):
return self._config.zero_config.legacy_stage1
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters
def graph_harvesting(self):
return self._config.graph_harvesting
def fp16_enabled(self):
return self._config.fp16_enabled
def bfloat16_enabled(self):
return self._config.bfloat16_enabled
def fp16_master_weights_and_gradients(self):
return self._config.fp16_master_weights_and_gradients
def amp_enabled(self):
return self._config.amp_enabled
def amp_params(self):
return self._config.amp_params
def fp16_auto_cast(self):
return self._config.fp16_auto_cast
def loss_scale(self):
return self._config.loss_scale
def gradient_accumulation_steps(self):
return self._config.gradient_accumulation_steps
def use_node_local_storage(self):
return self._config.use_node_local_storage
def load_universal_checkpoint(self):
return self._config.load_universal_checkpoint
@property
def communication_data_type(self):
res = self._config.communication_data_type
if res is not None:
return res
if self.fp16_enabled():
return torch.float16
if self.bfloat16_enabled():
return torch.bfloat16
return torch.float32
@communication_data_type.setter
def communication_data_type(self, value):
self._config.communication_data_type = value
def postscale_gradients(self):
return not self._config.prescale_gradients
def gradient_predivide_factor(self):
return self._config.gradient_predivide_factor
def steps_per_print(self):
return self._config.steps_per_print
def zero_allgather_partitions(self):
return self._config.zero_config.allgather_partitions
def zero_round_robin_gradients(self):
return self._config.zero_config.round_robin_gradients
def zero_hpz_partition_size(self):
return self._config.zero_config.zero_hpz_partition_size
def zero_quantized_weights(self):
return self._config.zero_config.zero_quantized_weights
def zero_quantized_nontrainable_weights(self):
return self._config.zero_config.zero_quantized_nontrainable_weights
def zero_quantized_gradients(self):
return self._config.zero_config.zero_quantized_gradients
def dump_state(self):
return self._config.dump_state
def gradient_clipping(self):
return self._config.gradient_clipping
def dynamic_loss_scale(self):
return self._config.loss_scale == 0
def initial_dynamic_scale(self):
return self._config.initial_dynamic_scale
def dynamic_loss_scale_args(self):
return self._config.dynamic_loss_scale_args
def swap_tensor_config(self):
return self._config.swap_tensor_config
def aio_config(self):
return self._config.aio_config
def get_data_types(self):
model_dtype = torch.float32
if self.fp16_enabled():
model_dtype = torch.float16
elif self.bfloat16_enabled():
model_dtype = torch.bfloat16
if self._config.grad_accum_dtype is None:
if model_dtype == torch.bfloat16 and not self.zero_optimization():
grad_accum_dtype = torch.float32
else:
grad_accum_dtype = model_dtype
else:
grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value
return (model_dtype, grad_accum_dtype)
def _optimizer_has_ckpt_event_prologue(self):
return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue')
def _optimizer_has_ckpt_event_epilogue(self):
return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue')
def _configure_lr_scheduler(self, client_lr_scheduler):
# First check for scheduler in json configuration
lr_scheduler = self._scheduler_from_config(self.optimizer)
if lr_scheduler:
log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
self.lr_scheduler = lr_scheduler
else:
if isinstance(client_lr_scheduler, Callable):
log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
else:
log_dist('DeepSpeed using client LR scheduler', ranks=[0])
self.lr_scheduler = client_lr_scheduler
log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
def _configure_checkpointing(self, dist_init_required):
self.checkpoint_engine = TorchCheckpointEngine()
if self._config is not None and self._config.nebula_config.enabled:
try:
from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
NebulaCheckpointEngine
self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
except ImportError as err:
logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
self.checkpoint_engine = TorchCheckpointEngine()
dp_rank = groups._get_sequence_data_parallel_rank()
rank = self.local_rank if self.use_node_local_storage() else dp_rank
# only the first data parallel process needs to store the model checkpoint
# if you want to use node local storage this must be done by rank 0 on each
# node
self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()
and self.is_first_weights_partition_group())
if self.zero_optimization() or self.bfloat16_enabled():
param_rank = dist.get_rank(group=self.optimizer.zp_process_group)
# Only the first parameter parallel process needs to store the
# optimizer state checkpoints for zero
self.save_zero_checkpoint = param_rank == dp_rank
def _scheduler_from_config(self, optimizer):
scheduler_name = self.scheduler_name()
if scheduler_name is not None:
if hasattr(lr_schedules, scheduler_name):
scheduler = getattr(lr_schedules, scheduler_name)
else:
assert hasattr(torch.optim.lr_scheduler,
scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)
scheduler_params = self.scheduler_params()
instantiated_scheduler = scheduler(optimizer, **scheduler_params)
return instantiated_scheduler
else:
return None
def _set_distributed_vars(self, args):
device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
if device_rank >= 0:
get_accelerator().set_device(device_rank)
self.device = torch.device(get_accelerator().device_name(), device_rank)
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
else:
self.world_size = 1
self.global_rank = 0
self.device = torch.device(get_accelerator().device_name())
# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):
# After the distributed backend is initialized we are guaranteed the LOCAL_RANK
# environment variable is set. We must align args.local_rank to this value for
# backwards compatibility with scripts relying on [args|self].local_rank containing
# the correct local rank info. _do_args_sanity_check will ensure this is the case.
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank)
assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \
"not sure how to proceed as we're seeing conflicting local rank info."
os.environ['LOCAL_RANK'] = local_rank
self.local_rank = int(os.environ['LOCAL_RANK'])
if hasattr(args, 'local_rank'):
args.local_rank = self.local_rank
# Validate command line arguments
def _do_args_sanity_check(self, args):
assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
"variable; it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch launcher. If using a " \
"different launcher, please ensure LOCAL_RANK is set prior to initializing deepspeed."
if hasattr(args, 'local_rank') and args.local_rank is not None:
assert isinstance(args.local_rank,
int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
if args.local_rank >= 0:
env_local_rank = int(os.environ.get("LOCAL_RANK"))
assert (
env_local_rank == args.local_rank
), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
def _is_supported_optimizer(self, optimizer_name):
return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None)
def _supported_optims(self):
FairseqOptimizer = None
try:
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
except ImportError:
pass
expected_optim_types = [Optimizer]
if FairseqOptimizer:
# fairseq optims are not torch.optim objects
expected_optim_types.append(FairseqOptimizer)
return expected_optim_types
# Validate configuration based on command line arguments
def _do_sanity_check(self):
expected_optim_types = self._supported_optims()
expected_optim_types += [type(None), Callable]
assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \
f'Client Optimizer is of unexpected type {type(self.client_optimizer)}'
if not self.client_optimizer:
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(
self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name())
if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER):
assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format(
self.optimizer_name())
# Detect invalid combinations of client optimizer and client scheduler
if isinstance(self.client_lr_scheduler, _LRScheduler):
assert isinstance(self.client_optimizer, Optimizer), \
f'Client Optimizer (type = {type(self.client_optimizer)}) is not instantiated but Client LR Scheduler is instantiated'
def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE:
return False
return True
for p in self.module.parameters():
# Broadcast the model for different parameters
if is_moe_param(p):
if torch.is_tensor(p) and is_replicated(p):
dist.broadcast(p,
groups._get_expert_broadcast_src_rank(p.group_name),
group=self.expert_data_parallel_group[p.group_name])
else:
if torch.is_tensor(p) and is_replicated(p):
dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group)
@staticmethod
def __check_params(model: Module, dtype: torch.dtype) -> None:
return  # early return disables the dtype check; the code below is unreachable
if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
f"not {dtype}: "
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")
def _set_client_model(self, model):
# register client model in _modules so that nn.module methods work correctly
modules = self.__dict__.get('_modules')
modules['module'] = model
# register module attribute in engine but avoid getattr
self.__dict__['module'] = model
def _configure_distributed_model(self, model):
self._set_client_model(model)
is_zero_init_model = self.zero_optimization_partition_weights() and any(
[hasattr(param, "ds_id") for param in self.module.parameters()])
if self.fp16_enabled():
if is_zero_init_model:
self.__check_params(self.module, torch.half)
self.module.half()
elif self.bfloat16_enabled():
if is_zero_init_model:
self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16()
else:
self.__check_params(self.module, torch.float)
# zero.Init() handles device placement of model
if not (self.dont_change_device or is_zero_init_model):
self.module.to(self.device)
# MoE related initialization
for _, module in self.module.named_modules():
if isinstance(module, MoE):
self.has_moe_layers = True
self.num_experts.append(module.num_experts)
if self.has_moe_layers:
for _, module in self.module.named_modules():
if isinstance(module, TopKGate):
self.gate_modules.append(module)
if self.wall_clock_breakdown():
module.wall_clock_breakdown = True
if isinstance(module, MOELayer):
self.moe_layers.append(module)
if self.wall_clock_breakdown():
module.wall_clock_breakdown = True
# Pass the mpu from here to groups. For subsequent use, just query groups
if self.mpu is not None:
groups.mpu = self.mpu
# Set deepspeed parallelism spec. for the model including expert parallelism
for _, module in self.module.named_modules():
if hasattr(module, 'set_deepspeed_parallelism'):
module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_)
# Query the groups module to get information about various parallel groups
self.local_all_to_all_group = None
if self.zero_quantized_gradients():
log_dist("Using quantized gradients", ranks=[0])
self.local_all_to_all_group = groups._get_local_all_to_all_group()
self.data_parallel_group = groups._get_data_parallel_group()
self.dp_world_size = groups._get_data_parallel_world_size()
self.zp_world_size = zp_manager.zp_size
self.seq_data_parallel_group = groups._get_sequence_data_parallel_group()
self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size()
self.mp_world_size = groups._get_model_parallel_world_size()
self.expert_parallel_group = groups._get_expert_parallel_group_dict()
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
if self.sequence_parallel_size > 1:
self.communication_data_type = self._config.seq_parallel_communication_data_type
if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
# check if parameters are duplicated in optimizer param_groups
def _check_for_duplicates(self, optimizer):
for name, param in self.module.named_parameters():
param_id = id(param)
def ids_list(group):
return [id(param) for param in group]
occurrence = sum([
ids_list(group['params']).count(param_id)
for group in optimizer.param_groups
])
assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior."
def _do_optimizer_sanity_check(self, basic_optimizer):
model_dtype, grad_accum_dtype = self.get_data_types()
zero_enabled = self.zero_optimization()
amp_enabled = self.amp_enabled()
# config based assertions
assert (
not (amp_enabled and zero_enabled)
), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode, which performs similarly to amp opt_level=O2"
if zero_enabled:
if not is_zero_supported_optimizer(basic_optimizer):
assert (
self.zero_allow_untested_optimizer()
), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
if self.global_rank == 0:
logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****")
if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage(
) == 1 and not self.zero_cpu_offload():
return BFLOAT16
return ZERO_OPTIMIZATION
elif amp_enabled:
if model_dtype != grad_accum_dtype:
raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use Amp")
if model_dtype == torch.bfloat16 or model_dtype == torch.float16:
raise NotImplementedError("Cannot enable amp together with (legacy) fp16 or bfloat16 mode")
try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError:
# If apex/amp is available it will be imported above
raise RuntimeError("Unable to import apex/amp, please make sure it is installed")
return AMP
# data type checks
elif model_dtype == grad_accum_dtype:
if model_dtype == torch.bfloat16:
if self.pipeline_parallelism:
logger.warning(
"**** BF16 gradient accumulation is not numerically safe with a large number of accumulation steps, proceed with caution *****"
)
return BFLOAT16
else:
raise NotImplementedError(
"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
)
if model_dtype == torch.float16:
return FP16
# else optimizer_wrapper = None
elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
return BFLOAT16
else:
raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type")
return None
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
if client_optimizer is None:
basic_optimizer = self._configure_basic_optimizer(model_parameters)
log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
else:
if isinstance(client_optimizer, tuple(self._supported_optims())):
basic_optimizer = client_optimizer
log_dist('Using client Optimizer as basic optimizer', ranks=[0])
else:
basic_optimizer = client_optimizer(model_parameters)
log_dist('Using client callable to create basic optimizer', ranks=[0])
if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam):
if self.zero_force_ds_cpu_optimizer():
msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
raise ZeRORuntimeException(msg)
basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0]
log_dist("Removing param_group that has no 'params' in the basic Optimizer", ranks=[0])
self._check_for_duplicates(basic_optimizer)
self.basic_optimizer = basic_optimizer
log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0])
optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)
if optimizer_wrapper == ZERO_OPTIMIZATION:
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif optimizer_wrapper == AMP:
amp_params = self.amp_params()
log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
self._set_client_model(model)
self._broadcast_model()
# TODO: maybe need to broadcast experts differently?
elif optimizer_wrapper == FP16:
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
elif optimizer_wrapper == BFLOAT16:
self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
else:
self.optimizer = basic_optimizer
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0])
self.compression_scheduler = self._configure_compression_scheduler()
self.quantizer = self._configure_quantization()
def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
if optimizer_parameters is None:
optimizer_parameters = {}
if "max_grad_norm" in optimizer_parameters.keys():
raise ValueError(
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
)
if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)
# Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode
if torch_adam:
if not effective_adam_w_mode:
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
else:
optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
else:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
else:
from deepspeed.ops.adam import FusedAdam
optimizer = FusedAdam(
model_parameters,
**optimizer_parameters,
adam_w_mode=effective_adam_w_mode,
)
elif self.optimizer_name() == ADAGRAD_OPTIMIZER:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
else:
optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
assert not self.zero_optimization(), "1bit-Adam is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.adam import OnebitAdam
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16")
elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER:
assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16')
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16")
elif self.optimizer_name() == LION_OPTIMIZER:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.lion import DeepSpeedCPULion
optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters)
else:
from deepspeed.ops.lion import FusedLion
optimizer = FusedLion(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == MUADAM_OPTIMIZER:
try:
from mup import MuAdam
except ImportError:
logger.error("Install mup to use MuAdam optimizer")
# Re-raise: continuing would only fail later with a NameError on MuAdam
raise
optimizer = MuAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == MUADAMW_OPTIMIZER:
try:
from mup import MuAdamW
except ImportError:
logger.error("Install mup to use MuAdamW optimizer")
# Re-raise: continuing would only fail later with a NameError on MuAdamW
raise
optimizer = MuAdamW(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == MUSGD_OPTIMIZER:
try:
from mup import MuSGD
except ImportError:
logger.error("Install mup to use MuSGD optimizer")
# Re-raise: continuing would only fail later with a NameError on MuSGD
raise
optimizer = MuSGD(model_parameters, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
return optimizer
def _configure_compression_scheduler(self):
return compression_scheduler(self.module, self._config.compression_config)
def _configure_random_ltd_scheduler(self, configs):
return RandomLTDScheduler(configs)
def _configure_quantization(self):
(
quantize_weight_in_forward,
quantize_enabled,
q_groups,
q_mixed_fp16,
q_change_ratio,
q_type,
q_rounding,
q_verbose,
use_quantizer_kernel,
) = self.quantize_training()
if quantize_enabled and not quantize_weight_in_forward:
assert self.fp16_enabled(
), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
quantizer = None
if quantize_enabled and not quantize_weight_in_forward:
from deepspeed.runtime.quantize import Quantizer
quantizer = Quantizer(
q_groups,
q_mixed_fp16,
q_change_ratio,
q_type,
q_rounding,
q_verbose,
self.eigenvalue_enabled(),
use_quantizer_kernel,
self.eigenvalue_layer_num() if self.eigenvalue_enabled() else 0,
)
return quantizer
def _configure_fp16_optimizer(self, optimizer):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if APEX_INSTALLED:
fused_opts = (apex.optimizers.FusedAdam, FusedAdam)
else:
fused_opts = FusedAdam
if isinstance(optimizer, fused_opts) \
or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
if self.dynamic_loss_scale():
log_dist(f'Creating fp16 optimizer with dynamic loss scale', ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
dynamic_loss_scale=True,
initial_dynamic_scale=initial_dynamic_scale,
dynamic_loss_args=dynamic_loss_args,
mpu=self.mpu,
clip_grad=clip_grad,
fused_adam_legacy=self.optimizer_legacy_fusion(),
timers=timers,
has_moe_layers=self.has_moe_layers,
)
else:
log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
static_loss_scale=self.loss_scale(),
mpu=self.mpu,
clip_grad=clip_grad,
fused_adam_legacy=self.optimizer_legacy_fusion(),
has_moe_layers=self.has_moe_layers,
)
else:
log_dist(f'Creating fp16 unfused optimizer (dynamic loss scale: {self.dynamic_loss_scale()})', ranks=[0])
optimizer = FP16_UnfusedOptimizer(
optimizer,
deepspeed=self,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=dynamic_loss_args,
mpu=self.mpu,
clip_grad=clip_grad,
fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER,
)
return optimizer
def _configure_bf16_optimizer(self, optimizer):
clip_grad = self.gradient_clipping()
if optimizer is None:
optimizer = DummyOptim(list(self.module.parameters()))
log_dist('Creating BF16 optimizer', ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
optimizer = BF16_Optimizer(optimizer,
self.param_names,
mpu=self.mpu,
clip_grad=clip_grad,
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.seq_data_parallel_group,
timers=timers,
grad_acc_dtype=self.get_data_types()[1],
graph_harvesting=self.graph_harvesting())
return optimizer
def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
mics_shard_size = self.mics_shard_size()
model_dtype, gradient_accumulation_dtype = self.get_data_types()
timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
if optimizer is None:
optimizer = DummyOptim(list(self.module.parameters()))
if self.zero_legacy_stage1():
raise Exception(
"The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO."
)
if zero_stage <= ZeroStageEnum.gradients:
overlap_comm = self.zero_overlap_comm()
contiguous_gradients = self.zero_contiguous_gradients()
round_robin_gradients = self.zero_round_robin_gradients()
assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
# Overlap and contiguous grads are meaningless in stage 1 and are ignored
if zero_stage == ZeroStageEnum.optimizer_states:
overlap_comm = False
round_robin_gradients = False
# Non-MoE requires contiguous grads to be disabled w. stage 1
if not self.has_moe_layers:
contiguous_gradients = False
if isinstance(self.module, PipelineModule):
if overlap_comm:
logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.")
overlap_comm = False
optimizer = DeepSpeedZeroOptimizer(
optimizer,
self.param_names,
timers=timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=contiguous_gradients,
reduce_bucket_size=self.zero_reduce_bucket_size(),
use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(),
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.seq_data_parallel_group,
expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=overlap_comm,
offload_optimizer_config=self.zero_offload_optimizer(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
ignore_unused_parameters=self.zero_ignore_unused_parameters(),
partition_grads=zero_stage == ZeroStageEnum.gradients,
round_robin_gradients=round_robin_gradients,
has_moe_layers=self.has_moe_layers,
fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
gradient_accumulation_dtype=gradient_accumulation_dtype,
communication_data_type=self.communication_data_type,
elastic_checkpoint=self.zero_elastic_checkpoint())
elif zero_stage == ZeroStageEnum.weights:
assert not self.has_moe_layers, "MoE not supported with Stage 3"
if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0])
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None:
self._set_zero_group_parallelism()
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
optimizer = DeepSpeedZeRoOffload(
self.module,
timers=timers,
ds_config=self.config,
overlap_comm=self.zero_overlap_comm(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
mpu=self.mpu,
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
)
else:
log_dist(
f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer,'
f' MiCS is enabled {mics_shard_size>0},'
f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
ranks=[0])
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.seq_data_parallel_group,
all2all_process_group=self.local_all_to_all_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
offload_ratio=self.zero_partial_offload(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
gradient_accumulation_dtype=gradient_accumulation_dtype,
communication_data_type=self.communication_data_type,
zero_hpz_partition_size=self.zero_hpz_partition_size(),
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
)
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
return optimizer
def _return_mics_optimizer(self, basic_optimizer, timers):
from deepspeed.runtime.zero.mics import MiCS_Optimizer
model_dtype, gradient_accumulation_dtype = self.get_data_types()
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.seq_data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
gradient_accumulation_dtype=gradient_accumulation_dtype,
communication_data_type=self.communication_data_type)
return optimizer
def _configure_eigenvalue(self):
eigenvalue = Eigenvalue(
verbose=self.eigenvalue_verbose(),
max_iter=self.eigenvalue_max_iter(),
tol=self.eigenvalue_tol(),
stability=self.eigenvalue_stability(),
gas_boundary_resolution=self.eigenvalue_gas_boundary_resolution(),
layer_name=self.eigenvalue_layer_name(),
layer_num=self.eigenvalue_layer_num(),
)
return eigenvalue
def _configure_progressive_layer_drop(self):
pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma())
return pld
def _configure_curriculum_scheduler_legacy(self):
scheduler = CurriculumScheduler(self.curriculum_params_legacy())
return scheduler
@staticmethod
def is_map_style_dataset(obj):
return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")
@staticmethod
def is_iterable_style_dataset(obj):
return isinstance(obj, torch.utils.data.IterableDataset) # hasattr(obj, "__iter__") should work as well
def dataloader_drop_last(self):
return self._config.dataloader_drop_last
def was_step_applied(self) -> bool:
"""Returns True if the latest ``step()`` resulted in parameter updates.
Note that a ``False`` return is not an error condition. Steps are frequently
no-ops, such as between gradient accumulation boundaries or when overflows
occur.
Returns:
bool: Whether the latest ``step()`` modified model parameters.
"""
return self._step_applied
def deepspeed_io(self,
dataset,
batch_size=None,
route=ROUTE_TRAIN,
pin_memory=True,
data_sampler=None,
collate_fn=None,
num_local_io_workers=None):
if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)):
raise ValueError("Training data must be a torch Dataset")
if batch_size is None:
batch_size = self.train_micro_batch_size_per_gpu()
if collate_fn is None:
collate_fn = self.collate_fn
# Currently we only use timer in train route
deepspeed_io_timer = None
if route == ROUTE_TRAIN:
deepspeed_io_timer = self.tput_timer
# If mpu is provided, forward world size and parallel rank to sampler.
data_parallel_world_size = self.dp_world_size
data_parallel_rank = self.global_rank
if self.mpu is not None:
data_parallel_world_size = self.mpu.get_data_parallel_world_size()
data_parallel_rank = self.mpu.get_data_parallel_rank()
if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
data_sampler = torch.utils.data.DistributedSampler(
dataset,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank,
shuffle=False,
)
deepspeed_dataloader_config = {}
if self.curriculum_learning_enabled():
deepspeed_dataloader_config = {
CURRICULUM_LEARNING: self.curriculum_learning_enabled(),
DATA_EFFICIENCY: self.data_efficiency_config(),
DATA_PARALLEL_GROUP: self.data_parallel_group,
GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(),
GLOBAL_RANK: self.global_rank,
DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
}
return DeepSpeedDataLoader(dataset=dataset,
batch_size=batch_size,
pin_memory=pin_memory,
collate_fn=collate_fn,
local_rank=self.local_rank,
tput_timer=deepspeed_io_timer,
num_local_io_workers=num_local_io_workers,
data_sampler=data_sampler,
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank,
dataloader_drop_last=self.dataloader_drop_last(),
deepspeed_dataloader_config=deepspeed_dataloader_config)
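def train(self, mode=True):
r"""Put the wrapped module in training mode (evaluation mode if ``mode`` is False) and re-arm the unscaled-loss warning."""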
self.warn_unscaled_loss = True
self.module.train(mode)
def eval(self):
r"""Put the wrapped module in evaluation mode and re-arm the unscaled-loss warning."""
self.warn_unscaled_loss = True
self.module.train(False)
def _scale_loss_by_gas(self, prescaled_loss):
if isinstance(prescaled_loss, torch.Tensor):
scaled_loss = prescaled_loss / self.gradient_accumulation_steps()
elif isinstance(prescaled_loss, (tuple, list)):
scaled_loss = []
for l in prescaled_loss:
if isinstance(l, torch.Tensor):
scaled_loss.append(l / self.gradient_accumulation_steps())
else:
scaled_loss.append(l)
else:
scaled_loss = prescaled_loss
if self.warn_unscaled_loss:
logger.warning(f"DeepSpeed is unable to scale the loss because of its type: {type(prescaled_loss)}")
self.warn_unscaled_loss = False
return scaled_loss
@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation
Arguments:
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
if self.autotuning_profile_model_info():
ma = get_ma_status()
else:
see_memory_usage("Engine before forward", force=self.memory_breakdown())
flops_profiler_active = (self.flops_profiler_enabled()
and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
# Used to check that quantization happens at step 0
if self.global_steps == 0 and hasattr(self, "compression_scheduler"):
self.compression_scheduler.step(step_zero_check=True)
if self.quantizer:
tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
) == 2 else self.optimizer.fp16_groups
if self.compression_scheduler.weight_quantization_enabled:
self.quantizer.quantize(
tensor_to_quantize,
(self.optimizer.overflow if self.fp16_enabled() else False),
self.eigenvalue_enabled(),
None,
)
if flops_profiler_active:
self.flops_profiler.start_profile(ignore_list=None)
if self.module.training:
if self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())
if self.__class__.__name__ != "PipelineEngine":
# TODO: The above if condition is a HACK since for PipelineEngine
# it's difficult to inject argument in forward pass.
if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
if self.module.training and self.random_ltd_enabled():
self.random_ltd_scheduler.update_seq(self.global_steps)
if self.zero_optimization_partition_weights():
# Enable automated discovery of external parameters by indicating that
# we are in a forward pass.
for module in self.module.modules():
module._parameters._in_forward = True
self._start_timers(self.engine_timers.forward_timers)
if self.training_dataloader is None:
self.tput_timer.start()
if self.fp16_auto_cast():
inputs = self._cast_inputs_half(inputs)
loss = self.module(*inputs, **kwargs)
if self.zero_optimization_partition_weights():
# Disable automated discovery of external parameters
for module in self.module.modules():
module._parameters._in_forward = False
self._stop_timers(self.engine_timers.forward_timers)
if flops_profiler_active:
self.flops_profiler.stop_profile()
if self.autotuning_profile_model_info():
activation_mem = get_ma_status() - ma
self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
exit()
else:
see_memory_usage("Engine after forward", force=self.memory_breakdown())
return loss
def _cast_inputs_half(self, inputs):
if isinstance(inputs, (list, tuple)):
new_inputs = []
for v in inputs:
new_inputs.append(self._cast_inputs_half(v))
return inputs.__class__(new_inputs)
elif isinstance(inputs, dict):
new_inputs = {}
for k, v in inputs.items():
new_inputs[k] = self._cast_inputs_half(v)
return new_inputs
elif hasattr(inputs, 'half'):
return inputs.half()
else:
return inputs
def print_forward_breakdown(self, fwd_time):
gate_time = 0.0
moe_time = 0.0
falltoall = 0.0
salltoall = 0.0
for gate in self.gate_modules:
#logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms")
gate_time += gate.gate_time
for l in self.moe_layers:
#logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}")
moe_time += l.time_moe
falltoall += l.time_falltoall
salltoall += l.time_salltoall
# TODO: Allreduce/average them across ranks for more accurate timing.
# if deepspeed.comm.get_rank() == 0:
log_dist(
f"time (ms) | fwd: {fwd_time:.2f} (fwd_moe: {moe_time:.2f}, 1st_a2a: {falltoall:.2f}, 2nd_a2a: {salltoall:.2f}, top_k: {gate_time:.2f})",
ranks=[0])
@instrument_w_nvtx
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \
'allreduce_gradients() is not valid when bfloat16 and pipeline parallelism are both enabled'
# Pass (PP) gas boundary flag to optimizer (required for zero)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
self.optimizer, 'reduce_gradients'):
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
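allreduce_gradients: is deprecated, ignored, and will soon be removed
release_loss: bool, default: False
currently a no-op
retain_graph: bool, default: False
passed through to the underlying backward call where the optimizer path supports it
scale_wrt_gas: bool, default: True
whether to divide the loss by the gradient accumulation steps; overridden by the engine-level ``scale_wrt_gas`` when that is set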
"""
see_memory_usage("Engine before backward", force=self.memory_breakdown())
if self.scale_wrt_gas is not None:
scale_wrt_gas = self.scale_wrt_gas
if not allreduce_gradients:
logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")
# scale loss w.r.t. gradient accumulation if needed
if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
loss = self._scale_loss_by_gas(loss.float())
# Log training loss
self.losses += loss.mean().item()
if self.monitor.enabled:
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [(
f"Train/Samples/train_loss",
self.losses,
self.global_samples,
)]
self.monitor.write_events(self.summary_events)
self._start_timers(self.engine_timers.backward_timers)
assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
"must provide optimizer during init in order to use backward"
self._start_timers(self.engine_timers.backward_inner_timers)
if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not self.is_gradient_accumulation_boundary()
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward(retain_graph=retain_graph)
elif self.fp16_enabled():
if self.eigenvalue_enabled():
self.optimizer.backward(loss, create_graph=True, retain_graph=True)
else:
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.bfloat16_enabled():
self.optimizer.backward(loss)
else:
if self.eigenvalue_enabled():
loss.backward(create_graph=True, retain_graph=True)
else:
loss.backward(retain_graph=retain_graph)
self._stop_timers(self.engine_timers.backward_inner_timers)
self._start_timers(self.engine_timers.backward_reduce_timers)
if allreduce_gradients and self.enable_backward_allreduce:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()
self._stop_timers(self.engine_timers.backward_reduce_timers)
self._stop_timers(self.engine_timers.backward_timers)
if release_loss:
# loss.data = None
pass
see_memory_usage("Engine after backward", force=self.memory_breakdown())
return loss
def is_gradient_accumulation_boundary(self):
"""
Query whether the current micro-batch is at the boundary of
gradient accumulation, and thus will trigger gradient reductions and
an optimizer step.
Returns:
bool: if the current step is a gradient accumulation boundary.
"""
if self._is_gradient_accumulation_boundary is None:
return (self.micro_steps + 1) % \
self.gradient_accumulation_steps() == 0
else:
return self._is_gradient_accumulation_boundary
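# Illustrative sketch only, not part of the engine API: the default boundary rule
# above restated as a standalone function. The names `sketch_is_boundary`,
# `micro_steps`, and `gas` are hypothetical and exist only for this example.
def sketch_is_boundary(micro_steps, gas):
    # step() increments micro_steps after each micro-batch, so the micro-batch
    # currently in flight is number micro_steps + 1 (1-based).
    return (micro_steps + 1) % gas == 0
# Example: with gas=4, the micro_steps values 3 and 7 (the 4th and 8th
# micro-batches) are boundaries; the other values in range(8) are not.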
def set_gradient_accumulation_boundary(self, is_boundary):
"""
Manually overrides the DeepSpeed engine's gradient accumulation boundary state. This is an optional
feature and should be used with care. The state should be set to the intended
value before each forward/backward. The final forward/backward should have the
boundary state set to True. This style allows client code to call engine.step() only once after all
the gradient accumulation passes are complete. See example below:
.. code-block:: python
engine.set_gradient_accumulation_boundary(False)
for _ in range(gradient_accumulation_steps - 1):
micro_batch = next(data_loader)
loss = engine(micro_batch)
engine.backward(loss)
engine.set_gradient_accumulation_boundary(True)
micro_batch = next(data_loader)
loss = engine(micro_batch)
engine.backward(loss)
engine.step()
Arguments:
is_boundary (bool): are we at a gradient accumulation boundary or not?
"""
self._is_gradient_accumulation_boundary = is_boundary
self.optimizer.is_gradient_accumulation_boundary = is_boundary
def zero_grad(self):
"""
Zero parameter grads.
"""
for param_name, param in self.module.named_parameters():
param.grad = None
def clip_fp32_gradients(self):
clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
if self.gradient_clipping() > 0.0:
if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
self.clip_fp32_gradients()
elif self.amp_enabled():
# AMP's recommended way of doing clipping
# https://nvidia.github.io/apex/advanced.html#gradient-clipping
master_params = amp.master_params(self.optimizer)
clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
self.optimizer.step()
if hasattr(self.optimizer, '_global_grad_norm'):
self._global_grad_norm = self.optimizer._global_grad_norm
# Quantize the updated parameter if there is no overflow
if self.quantizer:
tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
) == 2 else self.optimizer.fp16_groups
if self.compression_scheduler.weight_quantization_enabled:
self.quantizer.quantize(
tensor_to_quantize,
(self.optimizer.overflow if self.fp16_enabled() else False),
self.eigenvalue_enabled(),
block_eigenvalue,
)
# zero grad in basic optimizer could be unreliable and may not exhibit
# the behavior that we want
if self.bfloat16_enabled():
# TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
self.optimizer.zero_grad()
else:
pass
elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
self.optimizer.zero_grad()
else:
self.zero_grad()
report_progress = self.global_rank == 0 if self.global_rank else True
# Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function.
overflow = False
if hasattr(self.optimizer, "overflow"):
overflow = self.optimizer.overflow
self._step_applied = not overflow
if overflow:
self.skipped_steps += 1
else:
self.compression_scheduler.step()
if self.lr_scheduler is not None:
try:
self.lr_scheduler.step(**(lr_kwargs or {}))
except TypeError:
# XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
# We don't currently have a way to specify lr_kwargs from
# pipe_engine.train_batch()
self.lr_scheduler.step(self.train_batch_size())
if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)
self.losses = 0.0
self.global_steps += 1
self.global_samples += self.train_batch_size()
def step(self, lr_kwargs=None):
r"""Execute the weight update step after forward and backward propagation
on effective_train_batch.
"""
see_memory_usage("Engine before step", force=self.memory_breakdown())
# Check early because self.global_steps is incremented at some point here.
# TODO: Delay self.global_steps increment until very end of this function.
flops_profiler_active = self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0
self._start_timers(self.engine_timers.step_timers)
assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
"must provide optimizer during init in order to use step"
report_progress = False
self._step_applied = False # assume False, will flip to True
# Update the model when we reach gradient accumulation boundaries
if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1
if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
and self.quantizer.any_precision_switch()):
log_dist("computing eigenvalue...", ranks=[0])
self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
self.optimizer.cur_scale)
if self.progressive_layer_drop:
self.progressive_layer_drop.update_state(self.global_steps)
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()
and self.quantizer.any_precision_switch()):
self._take_model_step(lr_kwargs, self.block_eigenvalue)
else:
self._take_model_step(lr_kwargs)
report_progress = self.global_rank == 0 if self.global_rank else True
self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress)
self._stop_timers(self.engine_timers.step_timers)
# Log learning rate
if self.monitor.enabled:
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [("Train/Samples/lr", self.get_lr()[0], self.global_samples)]
if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
self.summary_events.append((
"Train/Samples/loss_scale",
self.optimizer.cur_scale,
self.global_samples,
))
if (self.eigenvalue_enabled()
and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()):
# Materialize the values as a list so they can be indexed; `self.ev_values` was never defined.
ev_values = list(self.block_eigenvalue.values())
for i in range(len(ev_values)):
self.summary_events.append((
f"Train/Eigenvalues/ModelBlockParam_{i}",
ev_values[i][0],
self.global_samples,
))
self.monitor.write_events(self.summary_events)
# Check flops profiling
if flops_profiler_active:
if self.autotuning_enabled():
self.flops = self.flops_profiler.get_total_flops() * 3
self.fwd_duration = self.flops_profiler.get_total_duration()
else:
self.flops_profiler.print_model_profile(
profile_step=self.global_steps,
module_depth=self.flops_profiler_module_depth(),
top_modules=self.flops_profiler_top_modules(),
detailed=self.flops_profiler_detailed(),
output_file=self.flops_profiler_output_file(),
)
self.flops_profiler.end_profile()
if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1):
self._autotuning_exit()
if self.wall_clock_breakdown():
# Log micro timing and reset
self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown())
if self.wall_clock_breakdown() or self.flops_profiler_enabled():
# Log global timing and reset
if self.is_gradient_accumulation_boundary():
if self.monitor.enabled:
self._write_monitor()
if self.has_moe_layers:
fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False)
self.print_forward_breakdown(fwd_time=fwd_time)
self.timers.log(self.engine_timers.global_timers)
self.micro_steps += 1
see_memory_usage("Engine after step", force=self.memory_breakdown())
def _start_timers(self, timer_names):
for name in timer_names:
self.timers(name).start()
def _stop_timers(self, timer_names):
record = self.is_gradient_accumulation_boundary() and \
self.flops_profiler_enabled() and \
(self.global_steps >= self.flops_profiler_profile_step())
for name in timer_names:
self.timers(name).stop(record=record)
def _autotuning_exit(self):
if self.global_rank == 0:
msg = self.timers.get_mean([
FORWARD_GLOBAL_TIMER,
BACKWARD_GLOBAL_TIMER,
STEP_GLOBAL_TIMER,
], reset=False)
titer = 0.0
titer += msg[FORWARD_GLOBAL_TIMER] if FORWARD_GLOBAL_TIMER in msg else 0
titer += msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0
titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0
titer *= self.gradient_accumulation_steps()
msg["latency"] = titer
msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer
msg["throughput"] = self.train_batch_size() * 1_000_000 / \
msg["latency"]
print_json_dist(msg, [0], path=self.autotuning_metric_path())
log_dist(
f"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}",
ranks=[0])
import atexit
atexit.register(print, "Autotuning: done with running current ds config.")
exit()
def _write_monitor(self):
if self.global_rank == 0:
self.summary_events = [
(
"Train/Samples/elapsed_time_ms_forward",
self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False),
self.global_samples,
),
(
"Train/Samples/elapsed_time_ms_backward",
self.timers(BACKWARD_GLOBAL_TIMER).elapsed(reset=False),
self.global_samples,
),
(
"Train/Samples/elapsed_time_ms_backward_inner",
self.timers(BACKWARD_INNER_GLOBAL_TIMER).elapsed(reset=False),
self.global_samples,
),
(
"Train/Samples/elapsed_time_ms_backward_allreduce",
self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).elapsed(reset=False),
self.global_samples,
),
(
"Train/Samples/elapsed_time_ms_step",
self.timers(STEP_GLOBAL_TIMER).elapsed(reset=False),
self.global_samples,
),
]
self.monitor.write_events(self.summary_events)
def _get_optimizer_param(self, param_name):
result = []
if not self.optimizer:
return result
for group in self.optimizer.param_groups:
if param_name in group:
result.append(group[param_name])
else:
result.append(0.0)
return result
def get_lr(self):
return self._get_optimizer_param("lr")
def get_type(self):
return self._get_optimizer_param("type")
def get_mom(self):
if self.optimizer_name() in ["SGD", "RMSprop"]:
return self._get_optimizer_param("momentum")
else:
return self._get_optimizer_param("betas")
def get_pld_theta(self):
if self.progressive_layer_drop:
return self.progressive_layer_drop.get_theta()
else:
return None
def _report_progress(self, step):
lr = self.get_lr()
mom = self.get_mom()
log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0])
def allreduce_bucket(self, bucket, dp_group):
tensor = self.flatten(bucket)
tensor_to_allreduce = tensor
if self.communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(self.communication_data_type)
if self.postscale_gradients():
if self.gradient_predivide_factor() != 1.0:
tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor())
dist.all_reduce(tensor_to_allreduce, group=dp_group)
if self.gradient_average:
if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):
tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
else:
tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group))
dist.all_reduce(tensor_to_allreduce, group=dp_group)
if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
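# Illustrative sketch only, not part of the engine API: the scaling arithmetic of
# allreduce_bucket, assuming postscale_gradients() and gradient_average are both
# enabled. Plain floats stand in for tensors and sum() stands in for
# dist.all_reduce; `sketch_postscale_average` is a hypothetical name.
def sketch_postscale_average(per_rank_grads, predivide_factor=1.0):
    world_size = len(per_rank_grads)
    # Pre-divide on every rank before the reduction.
    scaled = [g * (1.0 / predivide_factor) for g in per_rank_grads]
    # Sum across ranks, which is what all_reduce does by default.
    reduced = sum(scaled)
    # Post-scale so that the net effect is a division by world_size.
    return reduced * (predivide_factor / world_size)
# Whatever predivide_factor is chosen, the result is (up to rounding) the mean
# of per_rank_grads; the factor only changes where the division happens.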
def allreduce_and_copy(self, small_bucket, dp_group):
allreduced = self.allreduce_bucket(small_bucket, dp_group)
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, dp_group)
small_bucket = []
numel = 0
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, dp_group)
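# Illustrative sketch only, not part of the engine API: the bucketing rule used by
# allreduce_no_retain, with integer element counts standing in for tensors.
# `sketch_split_into_buckets` is a hypothetical name.
def sketch_split_into_buckets(numels, numel_per_bucket):
    buckets, current, total = [], [], 0
    for n in numels:
        current.append(n)
        total += n
        # A bucket is flushed only after the threshold has been exceeded, so a
        # bucket may hold more than numel_per_bucket elements.
        if total > numel_per_bucket:
            buckets.append(current)
            current, total = [], 0
    # Flush whatever is left over.
    if current:
        buckets.append(current)
    return buckets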
def _get_gradients_for_reduction(self):
non_expert_grads = []
expert_grads = {}
if self.has_moe_layers:
for key in self.expert_data_parallel_group.keys():
expert_grads[key] = []
for param_name, param in self.module.named_parameters():
if not param.requires_grad:
continue
if param.grad is None:
# When some ranks have empty grads and others do not, create zero grads
# so that every rank reduces tensors of the same size. In the future it
# may make sense to support averaging w.r.t. a value other than the
# world size.
param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
grad_data = param.grad.data
if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
# Call param.grad without data to avoid problem with setting of updated grads
grad_data = SparseTensor(param.grad)
if is_moe_param(param):
expert_grads[param.group_name].append(grad_data)
else:
non_expert_grads.append(grad_data)
return non_expert_grads, expert_grads
def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
split_buckets = split_half_float_double_sparse(grads)
for bucket_type, bucket in split_buckets:
if self.pipeline_parallelism:
dp_group = self.mpu.get_data_parallel_group()
else:
dp_group = groups._get_sequence_data_parallel_group()
if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
else:
self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
for ep_name, expert_grads_group in expert_grads.items():
expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
for bucket_type, bucket in expert_split_buckets:
if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))
else:
# Separate between diff groups
self.allreduce_no_retain(bucket,
dp_group=groups._get_expert_data_parallel_group(ep_name),
numel_per_bucket=elements_per_buffer)
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
if grads is None:
non_expert_grads, expert_grads = self._get_gradients_for_reduction()
else:
assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE"
non_expert_grads = grads
self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer)
if self.has_moe_layers:
self._reduce_expert_gradients(expert_grads, elements_per_buffer)
def sparse_allreduce_no_retain(self, bucket, dp_group):
allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
# Densify sparse tensor and copy back to original location
for tensor in allreduced_sparses:
if tensor.is_sparse:
tensor.orig_dense_tensor.data = tensor.to_coo_tensor()
else:
tensor.orig_dense_tensor.copy_(tensor.to_dense())
def sparse_allreduce_bucket(self, bucket, dp_group):
sparse_list = []
for sparse in bucket:
sparse_list.append(self.sparse_allreduce(sparse, dp_group))
return sparse_list
def sparse_allreduce(self, sparse, dp_group):
original_data_type = sparse.values.dtype
if self.communication_data_type != sparse.values.dtype:
if self.communication_data_type in (torch.float16, torch.bfloat16):
indices = sparse.indices.to(torch.int32)
else:
indices = sparse.indices
values = sparse.values.to(self.communication_data_type)
else:
indices = sparse.indices
values = sparse.values
if self.postscale_gradients():
if self.gradient_average:
values.mul_(self.gradient_predivide_factor() /
(dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
else:
values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
indices_device_list = self.sparse_all_gather(indices, dp_group)
values_device_list = self.sparse_all_gather(values, dp_group)
sparse.indices = torch.cat(indices_device_list).to(torch.long)
sparse.values = torch.cat(values_device_list).to(original_data_type)
return sparse
def sparse_all_gather(self, value, dp_group):
my_size = torch.LongTensor([value.size()[0]]).to(self.device)
all_sizes = self.all_gather_scalar(my_size, dp_group)
max_size = torch.cat(all_sizes).max()
fill_size = max_size - my_size
assert value.dim() in [1, 2]
if value.dim() == 1:
if fill_size > 0:
value = torch.cat([value, value.new_empty(fill_size)])
tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))]
else:
if fill_size > 0:
value = torch.cat([value, value.new_empty(fill_size, value.size()[1])])
tensor_list = [
value.new_empty(max_size,
value.size()[1]) for _ in range(dist.get_world_size(group=dp_group))
]
dist.all_gather(tensor_list, value, group=dp_group)
tensors = []
for dev_idx, t in enumerate(tensor_list):
size = all_sizes[dev_idx][0]
tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device)))
return tensors
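# Illustrative sketch only, not part of the engine API: the pad, gather, and trim
# pattern of sparse_all_gather, with Python lists standing in for per-rank tensors
# and no real communication. `sketch_pad_gather_trim` is a hypothetical name.
def sketch_pad_gather_trim(per_rank_values, pad=None):
    # Every rank first learns the length held by every other rank.
    sizes = [len(v) for v in per_rank_values]
    max_size = max(sizes)
    # all_gather needs equally sized buffers, so shorter inputs are padded.
    padded = [list(v) + [pad] * (max_size - len(v)) for v in per_rank_values]
    # After gathering, each buffer is cut back to its true length.
    trimmed = [p[:size] for p, size in zip(padded, sizes)]
    return padded, trimmed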
def all_gather_scalar(self, value, dp_group):
tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))]
dist.all_gather(tensor_list, value, group=dp_group)
return tensor_list
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
sd = self.module.state_dict(destination, prefix, keep_vars)
# Remove frozen parameter weights from state_dict if specified
if exclude_frozen_parameters:
for n, p in self.module.named_parameters():
if not p.requires_grad and n in sd:
del sd[n]
if self.random_ltd_enabled():
sd = remove_random_ltd_state_dict(sd)
return sd
@staticmethod
def load_moe_state_dict(checkpoint_path,
tag,
state_dict,
old_moe_load,
model=None,
mpu=None,
num_experts=1,
checkpoint_engine=TorchCheckpointEngine()):
if old_moe_load:
expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())
num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size(
groups._get_max_expert_size_name())
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = checkpoint_engine.load(
DeepSpeedEngine._get_expert_ckpt_name(
checkpoint_path,
-1, # -1 means ignore layer_id
global_expert_id,
tag,
mpu),
map_location=torch.device('cpu'))
# Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()):
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
else:
moe_layer_id = 0
for n_module, module in model.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
group_name = module.expert_group_name
num_local_experts = module.num_local_experts
expp_rank = groups._get_expert_parallel_rank(group_name)
# loop all local_experts
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(
checkpoint_path, moe_layer_id, global_expert_id, tag, mpu),
map_location=torch.device('cpu'))
# print(expert_state_dict.keys())
# Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()):
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
moe_layer_id += 1
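# Illustrative sketch only, not part of the engine API: how load_moe_state_dict maps
# a global expert id to a local one and renames the checkpoint keys accordingly.
# Both function names are hypothetical; the key prefix is the one used above.
def sketch_global_expert_id(expp_rank, num_local_experts, local_expert_id):
    # Experts are laid out contiguously per expert-parallel rank.
    return expp_rank * num_local_experts + local_expert_id


def sketch_relabel_expert_keys(expert_state_dict, global_expert_id, local_expert_id):
    prefix = '.deepspeed_moe.experts.deepspeed_experts.'
    # Build a new dict instead of mutating in place, to keep the sketch simple.
    return {
        key.replace(f'{prefix}{global_expert_id}', f'{prefix}{local_expert_id}'): value
        for key, value in expert_state_dict.items()
    }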
def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
if fetch_z3_params:
params_to_fetch = [
p for p in self.module.parameters()
if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
]
else:
params_to_fetch = []
with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0):
module_state_dict = checkpoint['module']
if custom_load_fn:
custom_load_fn(src=module_state_dict, dst=self.module)
else:
self.module.load_state_dict(
module_state_dict, # TODO
strict=strict)
if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
for param in self.module.parameters():
if param.requires_grad:
continue
if param not in self.param_names:
raise ValueError(f"failed to find frozen {param} in named params")
name = self.param_names[param]
if hasattr(param, 'ds_id'):
param.ds_tensor.data.copy_(saved_frozen_params[name].data)
else:
param.data.copy_(saved_frozen_params[name].data)
def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):
return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}'
def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode):
file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode)
zero_ckpt_name = os.path.join(
checkpoints_path,
str(tag),
f"{file_prefix}_mp_rank_{mp_rank:02d}_optim_states.pt",
)
return zero_ckpt_name
def _get_zero_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
pp_rank = dist.get_rank(group=self.optimizer.zp_process_group)
bf16_mode = self.bfloat16_enabled()
return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
if mp_placeholder is not None:
mp_rank_str = mp_placeholder
else:
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
mp_rank_str = f"{mp_rank:02d}"
if self.zero_optimization_partition_weights():
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.zp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
f"{filename}_mp_rank_{mp_rank_str}_model_states.pt",
)
else:
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
"mp_rank_" + mp_rank_str + "_model_states.pt",
)
return ckpt_name
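# Illustrative sketch only, not part of the engine API: the model-states file naming
# used by _get_ckpt_name, with the ranks passed in explicitly instead of being read
# from the engine. `sketch_model_ckpt_name` and `zero_rank` are hypothetical names;
# zero_rank=None stands for the case where weights are not partitioned.
def sketch_model_ckpt_name(checkpoints_path, tag, mp_rank=0, zero_rank=None):
    import os
    mp_rank_str = f"{mp_rank:02d}"
    if zero_rank is not None:
        filename = f"zero_pp_rank_{zero_rank}_mp_rank_{mp_rank_str}_model_states.pt"
    else:
        filename = f"mp_rank_{mp_rank_str}_model_states.pt"
    return os.path.join(checkpoints_path, str(tag), filename)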
def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
ckpt_name = os.path.join(checkpoints_path, str(tag),
f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
return ckpt_name
@staticmethod
def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None):
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
if layer_id <= -1:
# Used to support old checkpoint loading
ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
else:
# Used to support new checkpoint loading
ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
return ckpt_name
def _get_all_ckpt_names(self, checkpoints_path, tag):
# It is required that (checkpoints_path, tag) are consistent among all ranks.
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
import glob
ckpt_files = glob.glob(ckpt_file_pattern)
ckpt_files.sort()
return ckpt_files
def load_checkpoint(self,
load_dir,
tag=None,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True,
load_module_only=False,
custom_load_fn=None):
"""
Load training checkpoint
Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag is read from the 'latest' file.
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
load_optimizer_states: Optional. Boolean to load the training optimizer states from the checkpoint, e.g. ADAM's momentum and variance.
load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.
load_module_only: Optional. Boolean to load only the model weights from the checkpoint, e.g. for warm-starting.
custom_load_fn: Optional. Custom model load function.
Returns:
A tuple of ``load_path`` and ``client_state``.
* ``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
* ``client_state``: State dictionary used for loading required training states in the client code.
Important: under ZeRO3, one cannot load a checkpoint with ``engine.load_checkpoint()`` right
after ``engine.save_checkpoint()``, because ``engine.module`` is partitioned and
``load_checkpoint()`` expects a pristine model. If you need to do so, reinitialize the engine
before calling ``load_checkpoint()``.
"""
if tag is None:
latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest"
latest_path = os.path.join(load_dir, latest_tag)
if os.path.isfile(latest_path):
with open(latest_path, "r") as fd:
tag = fd.read().strip()
else:
if self.load_universal_checkpoint():
raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist')
else:
logger.warning(
f"Unable to find latest file at {latest_path}, if trying to load latest "
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
)
return None, None
if self._optimizer_has_ckpt_event_prologue():
# Prepare for checkpoint load by ensuring all parameters are partitioned
self.optimizer.checkpoint_event_prologue()
load_path, client_states = self._load_checkpoint(load_dir,
tag,
load_module_strict=load_module_strict,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states,
load_module_only=load_module_only,
custom_load_fn=custom_load_fn)
load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
or self.bfloat16_enabled())
if load_zero_checkpoint:
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
if not success:
self.optimizer._restore_from_bit16_weights()
if self._optimizer_has_ckpt_event_epilogue():
self.optimizer.checkpoint_event_epilogue()
if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)
return load_path, client_states
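# Illustrative sketch only, not part of the engine API: how load_checkpoint resolves
# the tag when none is given. `sketch_resolve_tag` is a hypothetical name, and the
# sketch returns None when the file is missing, whereas the method above raises
# ValueError in the universal-checkpoint case.
def sketch_resolve_tag(load_dir, tag=None, universal=False):
    import os
    if tag is not None:
        return tag
    latest_path = os.path.join(load_dir, "latest_universal" if universal else "latest")
    if os.path.isfile(latest_path):
        # The file holds a single line with the tag of the most recent checkpoint.
        with open(latest_path, "r") as fd:
            return fd.read().strip()
    return None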
def _load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True,
load_module_only=False,
custom_load_fn=None):
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
is_pipe_parallel = isinstance(self.module, PipelineModule)
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel)
if checkpoint is None:
return None, None
fetch_z3_params = False
if self.zero_optimization_partition_weights() and not load_optimizer_states:
checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir)
fetch_z3_params = True
if is_pipe_parallel:
# Pipeline parallelism uses this to load its own checkpoint files.
self._curr_ckpt_path = os.path.join(load_dir, tag)
if self.has_moe_layers:
# print(checkpoint.keys())
old_moe_load = False
if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True
DeepSpeedEngine.load_moe_state_dict(load_dir,
tag,
state_dict=checkpoint['module'],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
num_experts=self.num_experts,
checkpoint_engine=self.checkpoint_engine)
if not self.load_universal_checkpoint():
self.load_module_state_dict(checkpoint=checkpoint,
strict=load_module_strict,
custom_load_fn=custom_load_fn,
fetch_z3_params=fetch_z3_params)
self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
if 'zp_world_size' not in checkpoint:
checkpoint['zp_world_size'] = self.zp_world_size
self.loaded_checkpoint_zp_world_size = checkpoint['zp_world_size']
optim_checkpoint = None
if load_module_only:
deepspeed_states = ['module']
if self.optimizer is not None and self.fp16_enabled():
self.optimizer.refresh_fp32_params()
else:
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
if self.has_moe_layers:
largest_group_name = groups._get_max_expert_size_name()
expp_rank = groups._get_expert_parallel_rank(largest_group_name)
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
else:
optim_checkpoint = checkpoint
if self.fp16_enabled() or self.bfloat16_enabled():
self.optimizer.load_state_dict(optim_checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
else:
optim_checkpoint = checkpoint
self.optimizer.load_state_dict(optim_checkpoint['optimizer'])
if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:
self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd'])
if self.training_dataloader is not None and self.curriculum_learning_enabled(
) and 'data_sampler' in checkpoint:
self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler'])
def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters):
result = set()
for name in original_set:
if name in loaded_parameters and name not in loaded_set:
continue # parameter existed in previous model and was not sparse
result.add(name)
for name in loaded_set:
if name in original_parameters:
result.add(name) # parameter exists in both configs and it was sparse
return result
if 'sparse_tensor_module_names' in checkpoint:
sparse_tensor_module_names = checkpoint['sparse_tensor_module_names']
elif 'csr_tensor_module_names' in checkpoint:
sparse_tensor_module_names = checkpoint['csr_tensor_module_names']
else:
sparse_tensor_module_names = None
if sparse_tensor_module_names is not None:
if load_module_strict:
self.sparse_tensor_module_names = sparse_tensor_module_names
else:
self.sparse_tensor_module_names = get_sparse_tensor_module_names(
self.sparse_tensor_module_names, sparse_tensor_module_names,
dict(self.module.named_parameters()), checkpoint["module"])
self.global_steps = checkpoint['global_steps']
self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size())
self.skipped_steps = checkpoint['skipped_steps']
self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
deepspeed_states = [
'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'zp_world_size',
'mp_world_size', 'data_sampler', 'random_ltd', 'dp_world_size',
]
client_state = {}
if load_lr_scheduler_states:
deepspeed_states.append('lr_scheduler')
if load_optimizer_states:
deepspeed_states.append('optimizer')
client_state = {key: value for key, value in checkpoint.items() if key not in deepspeed_states}
if optim_checkpoint is not None:
client_state['optimizer'] = optim_checkpoint['optimizer']
return load_path, client_state
def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
load_serial = None
# With serial (pipelined) checkpoint loading, loading starts from local rank 0;
# every other local rank blocks until its rank-1 peer has finished loading and notified it.
if self._config.zero_config.pipeline_loading_checkpoint:
assert self.zero_optimization_stage(
) == ZeroStageEnum.weights, "Pipeline checkpoint loading is only supported with ZeRO stage 3"
load_serial = torch.zeros(1).to(self.device)
if dist.get_local_rank() != 0:
dist.recv(tensor=load_serial, src=dist.get_rank() - 1)
if self.load_universal_checkpoint():
zero_sd_list = None
checkpoint_folder = f'{os.path.join(load_dir, tag)}'
else:
if load_optimizer_states and self.zp_world_size != self.loaded_checkpoint_zp_world_size:
raise ZeRORuntimeException("The checkpoint being loaded used a DP " \
f"world size of {self.loaded_checkpoint_zp_world_size} but the " \
f"current world size is {self.zp_world_size}. Automatic adjustment " \
"of ZeRO's optimizer state partitioning with a new world size is not " \
"currently supported.")
checkpoint_folder = None
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
return False
self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
load_optimizer_states=load_optimizer_states,
load_from_fp32_weights=self.zero_load_from_fp32_weights(),
checkpoint_folder=checkpoint_folder,
load_serial=load_serial)
if self.load_universal_checkpoint():
logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
else:
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True
def update_optimizer_step(self, step):
def set_step(d):
if isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step
optimizer = self.optimizer
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
set_step(state[p])
def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_rank=dp_rank,
bf16_mode=bf16_mode)
zero_ckpt_names.append(ckpt_name)
return zero_ckpt_names
def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_world_size=self.loaded_checkpoint_dp_world_size,
bf16_mode=bf16_mode)
for i, ckpt_name in enumerate(zero_ckpt_names):
if not os.path.exists(ckpt_name):
# transparently handle the old file pattern for optim_states
if "optim_states.pt" in ckpt_name:
ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt")
if os.path.exists(ckpt_name_try):
zero_ckpt_names[i] = ckpt_name_try
continue
return zero_ckpt_names
def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names):
zero_sd_list = []
for i, ckpt_name in enumerate(zero_ckpt_names):
_state = None
if ckpt_name is None:
_state = {OPTIMIZER_STATE_DICT: None}
# Fully load state for current rank
elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.zp_process_group) == i:
_state = self.checkpoint_engine.load(
ckpt_name,
map_location='cpu',
)
else:
_state = {OPTIMIZER_STATE_DICT: None}
zero_sd_list.append(_state)
zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list]
logger.info(f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}")
return zero_optimizer_sd
def _get_all_zero_checkpoints(self, load_dir, tag):
for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]:
zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode)
if zero_ckpt_names is not None:
# Warn if loading checkpoint of different bit16 type
if bf16_mode is not self.bfloat16_enabled():
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
return None
def _checkpoint_tag_validation(self, tag):
if self.checkpoint_tag_validation_enabled():
s_hash = hashlib.sha1(tag.encode())
bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)
max_bhash = bhash.clone()
min_bhash = bhash.clone()
dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)
dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)
valid = all(min_bhash == bhash) and all(max_bhash == bhash)
msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across "
"all ranks. Including rank unique information in checkpoint tag could cause issues when "
"restoring with different world sizes.")
if self.checkpoint_tag_validation_fail():
assert valid, msg
elif not valid:
logger.warning(msg)
def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False):
"""Save training checkpoint
Arguments:
save_dir: Required. Directory for saving the checkpoint
tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
used if not provided. Tag name must be the same across all ranks.
client_state: Optional. State dictionary used for saving required training states in the client code.
save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state.
Important: all processes must call this method and not just the process with rank 0. It is
because each process needs to save its master weights and scheduler+optimizer states. This
method will hang waiting to synchronize with other processes if it's called just for the
process with rank 0.
"""
if self._optimizer_has_ckpt_event_prologue():
# Custom preparation for checkpoint save, if applicable
self.optimizer.checkpoint_event_prologue()
rank = self.local_rank if self.use_node_local_storage() else self.global_rank
# This is to make sure the checkpoint names are created without collision
# There seems to be an issue when creating them in parallel
# Ensure save_dir directory exists
if rank == 0:
self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
dist.barrier()
if tag is None:
tag = f"global_step{self.global_steps}"
# Ensure tag is a string
tag = str(tag)
self.checkpoint_engine.create(tag)
# Ensure checkpoint tag is consistent across ranks
self._checkpoint_tag_validation(tag)
if self.has_moe_layers:
self.save_non_zero_checkpoint = False
self._create_checkpoint_file(save_dir, tag, False)
self._save_moe_checkpoint(save_dir,
tag,
client_state=client_state,
exclude_frozen_parameters=exclude_frozen_parameters)
# We distribute the task of saving layer checkpoint files among
# data parallel instances, so all procs should call _save_checkpoint.
# All procs then call module_state_dict(), but only procs of data
# parallel rank 0 save the general model params.
if not self.has_moe_layers:
self._create_checkpoint_file(save_dir, tag, False)
self._save_checkpoint(save_dir,
tag,
client_state=client_state,
exclude_frozen_parameters=exclude_frozen_parameters)
if self.save_zero_checkpoint:
self._create_zero_checkpoint_files(save_dir, tag)
self._save_zero_checkpoint(save_dir, tag)
if self._optimizer_has_ckpt_event_epilogue():
self.optimizer.checkpoint_event_epilogue()
# Save latest checkpoint tag
self.checkpoint_engine.commit(tag)
if save_latest and rank == 0:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
dist.barrier()
return True
def _get_non_moe_state_dict(self, full_state_dict):
"""
Get the state dict of the non-moe layers
"""
for key in list(full_state_dict.keys()):
if 'expert' in key and 'moe.gate.wg.weight' not in key:
full_state_dict.pop(key)
return full_state_dict
def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):
save_path = self._get_ckpt_name(save_dir, tag)
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None.
# Using layer_#_export_# to save the model's expert state_dict
moe_layer_id = 0
for n_module, module in self.module.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
group_name = module.expert_group_name
num_local_experts = module.num_local_experts
expp_rank = groups._get_expert_parallel_rank(group_name)
exp_dp_rank = groups._get_expert_data_parallel_rank(group_name)
# print(expp_rank, exp_dp_rank)
if exp_dp_rank != 0:
moe_layer_id += 1
continue
# get all moe parameters
moe_state_dict = {}
for n, p in module.state_dict().items():
if 'expert' in n and 'moe.gate.wg.weight' not in n:
moe_state_dict[n_module + '.' + n] = p
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
# print(moe_state_dict.keys()) # until now, everything is fine. So the bug happens at next few lines
# Reorder the moe name rank, so that each checkpoint only has one expert
experts_state_dict = defaultdict(dict)
for key in list(moe_state_dict.keys()):
m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
local_expert_id = None
if not m:
logger.warning(f'No expert found in key {key}.')
else:
local_expert_id = m.group(1)
global_expert_id = expp_rank * \
num_local_experts + int(local_expert_id)
expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}',
f'{moe_str_prefix}{global_expert_id}')
# truncating extra tensor (shared) storage
truncated = moe_state_dict.pop(key).clone().detach()
experts_state_dict[str(global_expert_id)][expert_key] = truncated
# now save the MoE parameters, one checkpoint file per global expert
for global_expert_id, expert_state_dict in experts_state_dict.items():
# save the moe parameters
moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
if self.random_ltd_enabled():
expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
self.checkpoint_engine.save(expert_state_dict, moe_save_path)
moe_layer_id += 1
self._curr_ckpt_path = os.path.join(save_dir, tag)
largest_group_name = groups._get_max_expert_size_name()
expp_rank = groups._get_expert_parallel_rank(largest_group_name)
exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name)
# In the case of E + D parallelism, only the
# first expert parallel group should save the expert weights
# since each expert parallel group is a copy of the model's experts
if exp_dp_rank != 0:
return
# Save optimizer states. They are different across each exp parallel rank.
optimizer_state = {
'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
}
# TODO: why use BufferedWriter not the path
file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
self.checkpoint_engine.save(optimizer_state, file_path)
# get non-moe parameters
model_state_dict = self._get_non_moe_state_dict(
self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
if expp_rank == 0:
# TODO: update num experts info,.. in checkpoint
state = {
'module':
model_state_dict,
'lr_scheduler':
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
'data_sampler':
self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
'random_ltd':
self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
'sparse_tensor_module_names':
self.sparse_tensor_module_names,
'skipped_steps':
self.skipped_steps,
'global_steps':
self.global_steps,
'global_samples':
self.global_samples,
'zp_world_size':
self.zp_world_size,
'dp_world_size':
self.dp_world_size,
'mp_world_size':
self.mp_world_size,
'num_experts':
self.num_experts
}
state.update(client_state)
logger.info(f'Saving model checkpoint: {save_path}')
self.checkpoint_engine.save(state, save_path)
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name)
try:
checkpoint_name = name_function(save_dir, tag)
path = os.path.dirname(checkpoint_name)
self.checkpoint_engine.makedirs(path, exist_ok=True)
except Exception:
logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}")
return False
return True
def _create_zero_checkpoint_files(self, save_dir, tag):
success = True
# zero checkpoint files are created sequentially
for rank in range(dist.get_world_size(self.optimizer.zp_process_group)):
if rank == self.global_rank:
success = self._create_checkpoint_file(save_dir, tag, True)
dist.barrier(group=self.optimizer.zp_process_group)
return success
def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):
save_path = self._get_ckpt_name(save_dir, tag)
zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None. The module_state_dict() implementation in
# PipelineEngine expects the save path to be set in self._curr_ckpt_path.
self._curr_ckpt_path = os.path.join(save_dir, tag)
module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
self._curr_ckpt_path = None
state = dict(module=module,
buffer_names=self._get_buffer_names(),
optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
if save_frozen_param else None,
shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
if save_frozen_param else None,
lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
data_sampler=self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
sparse_tensor_module_names=self.sparse_tensor_module_names,
skipped_steps=self.skipped_steps,
global_steps=self.global_steps,
global_samples=self.global_samples,
dp_world_size=self.seq_dp_world_size,
mp_world_size=self.mp_world_size,
ds_config=self.config,
ds_version=version)
state.update(client_state)
if self.save_non_zero_checkpoint:
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
self.checkpoint_engine.save(state, save_path)
def _get_buffer_names(self):
buffer_names = []
# we save buffer names so that we could extract later the real buffers from the saved
# state_dict["module"] in the non-zero checkpoint - the buffers are already there but they
# are intermixed with param placeholders
# have to traverse the tree to be able to skip non-persistent buffers
def get_layer_named_buffers(module, prefix=""):
for name, buf in module.named_buffers(recurse=False):
if buf is not None and name not in module._non_persistent_buffers_set:
buffer_names.append(prefix + name)
for name, child in module.named_children():
if child is not None:
get_layer_named_buffers(child, prefix + name + ".")
get_layer_named_buffers(self.module, prefix="")
return buffer_names
def _get_param_shape_func(self, param):
return param.ds_shape if hasattr(param, 'ds_id') else param.shape
def _get_param_fragment_func(self, param):
return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu()
def _get_zero_frozen_param_attributes(self, attr_func):
frozen_param_fragments = OrderedDict()
for param in self.module.parameters():
if param.requires_grad:
continue
if param not in self.param_names:
raise ValueError(f"failed to find frozen {param} in named params")
name = self.param_names[param]
frozen_param_fragments[name] = attr_func(param)
return frozen_param_fragments
def _get_zero_param_shapes(self):
"""Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the
optimizer. the names are exactly as in state_dict. The order is absolutely important, since
the saved data is just flattened data with no identifiers and requires reconstruction in the
same order it was saved.
We can't rely on self.module.named_parameters() to get the saved tensors, as some params
will be missing and others unsaved and then it'd be impossible to reconstruct state_dict
from the flattened weights.
optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions.
"""
param_group_shapes = []
cnt = 0
numel = 0
# zero2 started using a round_robin_bit16_groups which is a shuffled version of bit16_groups -
# if we don't use it, we get parameters ordered incorrectly
if hasattr(self.optimizer, "round_robin_bit16_groups"):
bit16_groups = self.optimizer.round_robin_bit16_groups
elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"):
bit16_groups = self.optimizer.bf16_groups
else:
bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(
) == 2 else self.optimizer.fp16_groups
for bit16_group in bit16_groups:
param_shapes = OrderedDict()
for param in bit16_group:
cnt += 1
numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel()
shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape
if param not in self.param_names:
raise ValueError("failed to find optimizer param in named params")
name = self.param_names[param]
param_shapes[name] = shape
# uncomment to debug zero_to_fp32.py problems
# if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})")
param_group_shapes.append(param_shapes)
# if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params")
return param_group_shapes
def _get_shared_params(self):
"""
Returns a dict of shared params, which can later be used to reconstruct the original state dict,
e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name
of the variable that isn't stored and the value is the actual param holding data.
"""
shared_index = {}
shared_params_by_full_name = {}
is_zero3_model = (self.zero_optimization_partition_weights()
and any(hasattr(param, "ds_id") for param in self.module.parameters()))
def get_layer_state_dict(module, prefix=""):
# handle params
for name, param in module.named_parameters(recurse=False):
if param is None or (is_zero3_model and not hasattr(param, "ds_id")):
continue
key = prefix + name
# When weights are managed by stage 3, we can't rely on param.data_ptr() as it will be reused
# as weights get gathered and reduced, but param.ds_id is unique across all zero weights
# (and shared params will have the same param.ds_id)
param_id = param.ds_id if is_zero3_model else param.data_ptr()
if param_id in shared_index:
# shared weights
#print(f"`{key}` is shared with `{shared_index[param_id]}`")
shared_params_by_full_name[key] = shared_index[param_id]
else:
shared_index[param_id] = key
for name, child in module.named_children():
if child is not None:
get_layer_state_dict(child, prefix + name + ".")
if dist.get_rank() == 0:
get_layer_state_dict(self.module, prefix="")
return shared_params_by_full_name
def _copy_recovery_script(self, save_path):
base_dir = os.path.dirname(os.path.dirname(__file__))
script = "zero_to_fp32.py"
src = os.path.join(base_dir, "utils", script)
dst = os.path.join(save_path, script)
#logger.info(f"creating recovery script {dst}")
copyfile(src, dst)
self._change_recovery_script_permissions(dst)
def _change_recovery_script_permissions(self, dst):
# make executable (safeguard for file shares - Azure as example)
try:
os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
except (FileNotFoundError, PermissionError) as e:
#this message is used in unit test TestZeRONonDistributed
logger.info(
f'Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.'
)
def _save_zero_checkpoint(self, save_path, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version)
self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)
if self.global_rank == 0:
self._copy_recovery_script(save_path)
ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')
def _zero3_consolidated_16bit_state_dict(self):
"""
Get a full non-partitioned state_dict with fp16 weights on cpu.
Important: this function must be called on all ranks and not just rank 0.
This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:
1. consolidates the weights from different partitions on gpu0
2. works on one layer at a time to require as little gpu0 memory as possible, by
moving the already consolidated weights to cpu
3. takes care to keep the shared params shared when gradually copying the params to cpu
Returns:
a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
"""
if not self.zero_optimization_partition_weights():
raise ValueError("this function requires ZeRO-3 mode")
state_dict = OrderedDict() if dist.get_rank() == 0 else None
shared_params = {}
def get_layer_state_dict(module, prefix=""):
# gather one layer at a time to be memory-efficient
# must use modifier_rank=0 to release GPU memory after each layer gathered
#see_memory_usage("before GatheredParameters", force=True)
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if dist.get_rank() == 0:
# handle params
for name, param in module.named_parameters(recurse=False):
if param is None:
continue
key = prefix + name
# can't rely on param.data_ptr() as it will be reused as weights get
# gathered and reduced, but param.ds_id is unique across all zero weights
# (and shared params will have the same param.ds_id)
if param.ds_id in shared_params:
# shared weights
#print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
state_dict[key] = state_dict[shared_params[param.ds_id]]
else:
state_dict[key] = param.detach().cpu()
shared_params[param.ds_id] = key
#print(f"param {param.ds_id} {param.shape} {key} ")
# now buffers - not sure if need to take care of potentially shared weights here
for name, buf in module.named_buffers(recurse=False):
if (buf is not None and name not in module._non_persistent_buffers_set):
state_dict[prefix + name] = buf.detach().cpu()
#see_memory_usage("after GatheredParameters", force=True)
for name, child in module.named_children():
if child is not None:
get_layer_state_dict(child, prefix + name + ".")
# Prepare for checkpoint save by ensuring all parameters are partitioned
if self._optimizer_has_ckpt_event_prologue():
self.optimizer.checkpoint_event_prologue()
see_memory_usage("before get_layer_state_dict", force=False)
get_layer_state_dict(self.module, prefix="")
see_memory_usage("after get_layer_state_dict", force=False)
if self._optimizer_has_ckpt_event_epilogue():
self.optimizer.checkpoint_event_epilogue()
return state_dict
def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
"""has been renamed to save_16bit_model, keeping this around for backwards
compatibility"""
return self.save_16bit_model(save_dir, save_filename)
def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
"""
Save 16bit model weights
This method saves the 16bit model weights at the desired destination.
Arguments:
save_dir: Required. Directory for saving the model
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
Returns:
``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
stage3_gather_16bit_weights_on_model_save is ``False``.
Important: all processes must call this method and not just the process with rank 0. It is
because the processes need to work in sync to gather the weights. This method will hang
waiting to synchronize with other processes if it's called just for the process with rank 0.
"""
path = os.path.join(save_dir, save_filename)
if self.zero_optimization_partition_weights():
if self.zero_gather_16bit_weights_on_model_save():
# consolidation is expensive in time and memory and therefore isn't a default
state_dict = self._zero3_consolidated_16bit_state_dict()
else:
# the model will be bogus if not consolidated so don't confuse the user by saving it
logger.info(
f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False")
return False
else:
state_dict = self.module.state_dict()
tag = f"global_step{self.global_steps}"
tag = str(tag)
self.checkpoint_engine.create(tag)
if dist.get_rank() == 0:
self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}, tag: {tag}")
self.checkpoint_engine.save(state_dict, path)
self.checkpoint_engine.commit(tag)
return True
def empty_partition_cache(self):
"""
Release GPU memory consumed by offloaded model parameters.
"""
if hasattr(self.optimizer, 'empty_partition_cache'):
self.optimizer.empty_partition_cache()
gc.collect()
get_accelerator().empty_cache()
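A note on the `latest` file: `save_checkpoint` above writes the checkpoint tag (by default `global_step{N}`) into a file named `latest` inside `save_dir`. The following is a minimal sketch, not part of the engine, of how client code might resolve that tag back to a checkpoint directory. The helper name `resolve_latest_checkpoint` is hypothetical, and the sketch assumes the checkpoint lives at `os.path.join(save_dir, tag)`, as `_curr_ckpt_path` suggests.

```python
import os
import tempfile


def resolve_latest_checkpoint(save_dir):
    """Return (tag, checkpoint_dir) read from the 'latest' file, or (None, None).

    Hypothetical helper: mirrors the layout written by save_checkpoint,
    where 'latest' holds the tag and the checkpoint sits in save_dir/tag.
    """
    latest_path = os.path.join(save_dir, "latest")
    if not os.path.isfile(latest_path):
        return None, None
    with open(latest_path, "r") as fd:
        tag = fd.read().strip()
    return tag, os.path.join(save_dir, tag)


# Usage: emulate what save_checkpoint leaves behind for global step 1000.
with tempfile.TemporaryDirectory() as demo_dir:
    demo_tag = "global_step1000"
    os.makedirs(os.path.join(demo_dir, demo_tag))
    with open(os.path.join(demo_dir, "latest"), "w") as fd:
        fd.write(demo_tag)
    print(resolve_latest_checkpoint(demo_dir))
```

Because only rank 0 writes `latest` (and only when `save_latest` is true), a missing file is treated here as "no checkpoint yet" rather than as an error.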
================================================
FILE: opensora/adaptor/modules.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
def fp32_layer_norm_forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None, self.eps).to(origin_dtype)
def fp32_silu_forward(self, inputs: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(inputs.float(), inplace=self.inplace).to(inputs.dtype)
def fp32_gelu_forward(self, inputs: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(inputs.float(), approximate=self.approximate).to(inputs.dtype)
def replace_with_fp32_forwards():
nn.GELU.forward = fp32_gelu_forward
nn.SiLU.forward = fp32_silu_forward
nn.LayerNorm.forward = fp32_layer_norm_forward
================================================
FILE: opensora/adaptor/stage_1_and_2.py
================================================
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
import pdb
from deepspeed import comm as dist
from packaging import version as pkg_version
from collections import OrderedDict
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage,
inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
from deepspeed.moe.utils import is_moe_param
from deepspeed.git_version_info import version
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.utils import groups
from opensora.adaptor.zp_manager import zp_manager
# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
OPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather'
OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients'
OPTIMIZER_STEP_TIMER = 'optimizer_step'
OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER]
def input(msg):
return
def split_half_float_double(tensors):
device_type = get_accelerator().device_name()
dtypes = [
"torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
"torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type)
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
from math import gcd  # fractions.gcd was removed in Python 3.9
return x * y // gcd(x, y)
def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
return (alignment - remainder) if remainder else remainder
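The padding arithmetic in `get_alignment_padding` can be checked without tensors. The sketch below is an illustration only: it takes plain element counts instead of a tensor list (the function name `alignment_padding` and the sizes are hypothetical), and returns how many elements must be appended so the flattened total becomes a multiple of `alignment`.

```python
def alignment_padding(num_elements, alignment):
    """Elements to append so num_elements becomes a multiple of alignment."""
    remainder = num_elements % alignment
    return (alignment - remainder) if remainder else 0


# Hypothetical sizes: three flattened tensors and an alignment of 8 elements.
tensor_sizes = [10, 7, 4]
total = sum(tensor_sizes)               # 21 elements in the flat buffer
padding = alignment_padding(total, 8)   # 21 % 8 == 5, so 3 more are needed
print(total, padding, (total + padding) % 8)  # prints: 21 3 0
```

When the total is already aligned, the remainder is zero and no padding is added, which matches the `else remainder` branch in the original function.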
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
def _get_padded_tensor(src_tensor, size):
if src_tensor.numel() >= size:
return src_tensor
padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
slice_tensor.data.copy_(src_tensor.data)
return padded_tensor
def contigous_flatten(tensors):
return _flatten_dense_tensors([tensor.contiguous() for tensor in tensors])
def all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group):
for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)):
partition_id = dist.get_rank(group=zp_process_group[group_id])
dp_world_size = dist.get_world_size(group=zp_process_group[group_id])
if dp_world_size == 1:
# no groups share optimizer states
# pipeline parallelism with bf16 calls this by default, even when the dp size is 1.
continue
input_tensor = partitioned_params[partition_id].contiguous()
# print(f"call all_gather_into_tensor_dp_groups, input size is {input_tensor.size()}, "
# f"output size is {group_flat.size()}")
#
# print(f"groups_flat.size = {groups_flat.numel()}")
# print(f"partitioned_param_groups = {sum([v.numel() for v in partitioned_param_groups])}")
dist.all_gather_into_tensor(group_flat, input_tensor, zp_process_group[group_id])
class DeepSpeedZeroOptimizer(ZeROOptimizer):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed Tutorial
"""
def __init__(self,
init_optimizer,
param_names,
timers,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
contiguous_gradients=True,
reduce_bucket_size=500000000,
use_multi_rank_bucket_allreduce=True,
allgather_bucket_size=5000000000,
dp_process_group=None,
expert_parallel_group=None,
expert_data_parallel_group=None,
reduce_scatter=True,
overlap_comm=False,
offload_optimizer_config=None,
mpu=None,
clip_grad=0.0,
gradient_accumulation_dtype=torch.float32,
communication_data_type=torch.float16,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
ignore_unused_parameters=True,
partition_grads=True,
round_robin_gradients=False,
has_moe_layers=False,
fp16_master_weights_and_gradients=False,
elastic_checkpoint=False):
if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
self.cpu_offload = True
self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory
else:
self.cpu_offload = False
self.cpu_offload_pin_memory = False
if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
logger.info(f"CPU Offload: {self.cpu_offload}")
logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')
# The fused optimizer does all the work. We need this layer for two reasons:
# 1. maintain the same user API as apex.fp16_utils
# 2. keep common stuff here in case we need to add a new fused optimizer later
self.elastic_checkpoint = elastic_checkpoint
self.param_names = param_names
self.mpu = mpu
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params requires grad
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
if not get_accelerator().is_available():
raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).")
self.optimizer = init_optimizer
# Use torch (un)flatten ops
self.flatten = contigous_flatten
self.unflatten = _unflatten_dense_tensors
# ZeRO stage 1 (False) or 2 (True)
self.partition_gradients = partition_grads
self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"
self.timers = timers
self.reduce_scatter = reduce_scatter
self.overlap_comm = overlap_comm
self.deepspeed_adam_offload = self.cpu_offload
self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'
zp_manager.init_group()
self.zp_process_group = zp_manager.zp_group
zp_rank = dist.get_rank(group=self.zp_process_group)
zp_size = dist.get_world_size(group=self.zp_process_group)
print(f"zp rank is {zp_rank}, zp_size={zp_size}")
self.dp_process_group = dp_process_group
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
# expert parallel group
self.ep_process_group = expert_parallel_group
# data parallel group for experts
self.expert_dp_process_group = expert_data_parallel_group
# data parallel size for non-experts
dp_size = dist.get_world_size(group=self.dp_process_group)
# For MoE models this may be different for different param groups.
# It will be modified during MoE setup later in the init.
self.real_zp_process_group = [self.zp_process_group for i in range(len(self.optimizer.param_groups))]
self.real_dp_process_group = [self.dp_process_group for i in range(len(self.optimizer.param_groups))]
self.partition_count = [zp_manager.zp_size for i in range(len(self.optimizer.param_groups))]
self.is_gradient_accumulation_boundary = True
# CPU-Offload requires contiguous gradients
self.contiguous_gradients = contiguous_gradients or self.cpu_offload
self.has_moe_layers = has_moe_layers
if self.has_moe_layers:
self._configure_moe_settings()
self._global_grad_norm = 0.
if mpu is None:
self.model_parallel_group = None
self.model_parallel_world_size = 1
self.model_parallel_rank = 0
else:
self.model_parallel_group = mpu.get_model_parallel_group()
self.model_parallel_world_size = mpu.get_model_parallel_world_size()
self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)
self.overflow = False
self.clip_grad = clip_grad
self.communication_data_type = communication_data_type
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
self.ignore_unused_parameters = ignore_unused_parameters
self.round_robin_gradients = round_robin_gradients
self.extra_large_param_to_reduce = None
self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients
if self.fp16_master_weights_and_gradients:
assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \
f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32." \
f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \
f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam."
if self.reduce_scatter:
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
# param flattened by groups
self.bit16_groups = []
self.bit16_groups_flat = []
# param partitioned by data parallel degree
# this will contain a list of equal sized tensors
# each of which will be updated by a different process
self.parallel_partitioned_bit16_groups = []
# a single 32-bit partition of the parallel partitioned parameters
# that this process will update
self.single_partition_of_fp32_groups = []
# param partition info
# These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
# These are the parameters that will be updated by this process directly
self.params_in_partition = []
# Offset from the first parameter in the self.params_in_partition
# the parameter boundaries may not align with partition boundaries
# so we need to keep track of the offset
self.first_offset = []
# number of elements per partition in each group
self.partition_size = []
# align all-gather send buffers to a 32-byte boundary
self.nccl_start_alignment_factor = 16  # 32-byte alignment / sizeof(fp16) = 16 elements
assert (
allgather_bucket_size % self.nccl_start_alignment_factor == 0
), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
self.all_reduce_print = False
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
self.gradient_accumulation_dtype = gradient_accumulation_dtype
if self.dtype != self.gradient_accumulation_dtype:
self.use_separate_grad_accum = True
else:
self.use_separate_grad_accum = False
if self.use_separate_grad_accum and not self.partition_gradients:
self.use_grad_accum_attribute = True
else:
self.use_grad_accum_attribute = False
self.round_robin_bit16_groups = []
self.round_robin_bit16_indices = []
# Use different parallel to do all_to_all_reduce related things
# padding on each partition for alignment purposes
self.groups_padding = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
partition_id = dist.get_rank(group=self.real_zp_process_group[i])
# push this group to list before modify
# TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
trainable_parameters = []
for param in param_group['params']:
if param.requires_grad:
param.grad_accum = None
trainable_parameters.append(param)
self.bit16_groups.append(trainable_parameters)
# not sure why apex was cloning the weights before flattening
# removing cloning here
see_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.bit16_groups[i])
empty_cache()
see_memory_usage(f"After moving param group {i} to CPU", force=False)
# Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
# This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
# For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
# to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
if self.round_robin_gradients:
round_robin_tensors, round_robin_indices = self._round_robin_reorder(
self.bit16_groups[i], dist.get_world_size(group=self.real_zp_process_group[i]))
else:
round_robin_tensors = self.bit16_groups[i]
round_robin_indices = list(range(len(self.bit16_groups[i])))
self.round_robin_bit16_groups.append(round_robin_tensors)
self.round_robin_bit16_indices.append(round_robin_indices)
# create flat buffer in CPU and move to GPU
self.bit16_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.round_robin_bit16_groups[i],
self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_zp_process_group[i])).to(
get_accelerator().current_device_name()))
see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)
# Record padding required for alignment
if partition_id == dist.get_world_size(group=self.real_zp_process_group[i]) - 1:
padding = self.bit16_groups_flat[i].numel() - sum(
[t.numel() for t in self.round_robin_bit16_groups[i]])
else:
padding = 0
self.groups_padding.append(padding)
if dist.get_rank(group=self.real_zp_process_group[i]) == 0:
see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)
# set model bit16 weight to slices of flattened buffer
self._update_model_bit16_weights(i)
# divide the flat weights into near-equal partitions, one per rank in the partitioning group
# each process will compute on a different partition
data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)
# print(f"self.bit16_groups_flat[i].size = {self.bit16_groups_flat[i].numel()}")
# print(f"data_parallel_partitions = {sum([v.numel() for v in data_parallel_partitions])}")
# verify that data partition start locations are aligned to 2 * nccl_start_alignment_factor bytes
for partitioned_data in data_parallel_partitions:
assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)
# A partition of the fp32 master weights that will be updated by this process.
# Note that the params in single_partition_of_fp32_groups is cloned and detached
# from the origin params of the model.
if not fp16_master_weights_and_gradients:
self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().float().detach())
else:
self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().half().detach())
# Set local optimizer to have flat params of its own partition.
# After this, the local optimizer will only contain its own partition of params.
# In that case, the local optimizer only saves the states (momentum, variance, etc.) related to its partition's params (ZeRO stage 1).
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_zp_process_group[i])
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
self.round_robin_bit16_groups[i], partition_size, partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
self.reduce_bucket_size = int(reduce_bucket_size)
self.use_multi_rank_bucket_allreduce = use_multi_rank_bucket_allreduce
self.allgather_bucket_size = int(allgather_bucket_size)
self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream()
# self.copy_grad_stream = get_accelerator().Stream()
self.callback_queued = False
self.param_dict = {}
# map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self._release_ipg_buffers()
self.previous_reduced_grads = None
self.ipg_bucket_has_moe_params = False
# simplified param id
self.param_id = {}
# interesting code: unique ids being assigned to individual parameters
largest_param_numel = 0
count = 0
for i, params_group in enumerate(self.bit16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
if param.numel() > largest_param_numel:
largest_param_numel = param.numel()
count = count + 1
for param_group in self.params_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = True
for param_group in self.params_not_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(param)] = False
if self.cpu_offload:
self.accumulated_grads_in_cpu = {}
self.norm_for_param_grads = {}
self.local_overflow = False
self.grad_position = {}
self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel,
device=self.device,
dtype=self.dtype)
if self.cpu_offload_pin_memory:
self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
self.temp_grad_buffer_for_cpu_offload)
self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,
device=get_accelerator().current_device_name(),
dtype=self.dtype)
for i, params_group in enumerate(self.bit16_groups):
self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i])
# mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
# number of grads in partition that still need to be computed
self.remaining_grads_in_partition = {}
# total number of grads in partition
self.total_grads_in_partition = {}
# stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
# stores the offset at which a parameter gradient needs to be inserted in a partition
self.grad_partition_insertion_offset = {}
# the offset in the gradient at which it must be inserted at the beginning of the partition
self.grad_start_offset = {}
# will store the averaged gradients required by this partition
self.averaged_gradients = {}
# For cpu_offload, will store the averaged gradients required by this partition
self.offload_gradient_dict = {}
# store index of first parameter in each partition
self.first_param_index_in_partition = {}
# initializes all data structures for implementing gradient partitioning
self.initialize_gradient_partitioning_data_structures()
# resets the data structure value for the next backward propagation
self.reset_partition_gradient_structures()
# creates backward hooks for gradient partitioning
if self.partition_gradients or self.overlap_comm:
self.create_reduce_and_remove_grad_hooks()
self.custom_loss_scaler = False
self.external_loss_scale = None
# we may have a way of fusing dynamic scale. Do not support for now
self.loss_scaler = CreateLossScaler(dtype=self.dtype,
static_loss_scale=static_loss_scale,
dynamic_scaling=dynamic_loss_scale,
dynamic_loss_args=dynamic_loss_args)
self.dynamic_loss_scale = self.loss_scaler.dynamic
if self.dtype != torch.float16:
# Only fp16 should use dynamic loss scaling
assert self.loss_scaler.cur_scale == 1.0
assert not self.dynamic_loss_scale
see_memory_usage("Before initializing optimizer states", force=True)
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states", force=True)
if dist.get_rank() == 0:
logger.info(f"optimizer state initialized")
if dist.get_rank(group=self.zp_process_group) == 0:
see_memory_usage(f"After initializing ZeRO optimizer", force=True)
self._link_all_hp_params()
self._enable_universal_checkpoint()
self._param_slice_mappings = self._create_param_mapping()
def _enable_universal_checkpoint(self):
for lp_param_group in self.bit16_groups:
enable_universal_checkpoint(param_list=lp_param_group)
def _create_param_mapping(self):
param_mapping = []
for i, _ in enumerate(self.optimizer.param_groups):
param_mapping_per_group = OrderedDict()
for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
lp_name = self.param_names[lp]
param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
param_mapping.append(param_mapping_per_group)
return param_mapping
def _link_all_hp_params(self):
dp_world_size = dist.get_world_size(group=self.zp_process_group)
if self.cpu_offload:
self._get_offload_gradient_dict()
for i, _ in enumerate(self.optimizer.param_groups):
# Link bit16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_zp_process_group[i])
partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
flat_hp_partition = self.single_partition_of_fp32_groups[i]
link_hp_params(lp_param_list=self.bit16_groups[i],
flat_hp_partition=flat_hp_partition,
gradient_dict=self.averaged_gradients,
offload_gradient_dict=self.offload_gradient_dict,
use_offload=self.cpu_offload,
param_group_index=i,
partition_start=partition_id * partition_size,
partition_size=partition_size,
partition_optimizer_state=self.optimizer.state[flat_hp_partition],
dp_group=self.real_zp_process_group[i])
def is_moe_group(self, group):
return 'moe' in group and group['moe']
def _configure_moe_settings(self):
# if we're using ZeRO stage 2, ensure contiguous gradients are used
if self.partition_gradients:
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
# NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
if not self.partition_gradients and not self.contiguous_gradients:
logger.warning(
"ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
assert any(
[self.is_moe_group(group) for group in self.optimizer.param_groups]
), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
self.is_moe_param_group = []
for i, group in enumerate(self.optimizer.param_groups):
if self.is_moe_group(group):
assert all([is_moe_param(param)
for param in group['params']]), "All params in MoE group must be MoE params"
self.real_zp_process_group[i] = self.expert_dp_process_group[group['name']]
self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']])
self.is_moe_param_group.append(True)
else:
self.is_moe_param_group.append(False)
assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE"
assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
def _update_model_bit16_weights(self, group_index):
updated_params = self.unflatten(self.bit16_groups_flat[group_index],
self.round_robin_bit16_groups[group_index])
for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
p.data = q.data
# set model fp16 weight to slices of reordered flattened buffer
for param_index, param in enumerate(self.bit16_groups[group_index]):
new_index = self.round_robin_bit16_indices[group_index][param_index]
param.data = self.round_robin_bit16_groups[group_index][new_index].data
def _round_robin_reorder(self, tensor_list, num_partitions):
# disable round robin if need to debug something
# return tensor_list, list(range(len(tensor_list)))
partition_tensors = {}
for i, tensor in enumerate(tensor_list):
j = i % num_partitions
if j not in partition_tensors:
partition_tensors[j] = []
partition_tensors[j].append((i, tensor))
reordered_tensors = []
reordered_indices = {}
for partition_index in partition_tensors.keys():
for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):
reordered_indices[original_index] = len(reordered_tensors)
reordered_tensors.append(tensor)
return reordered_tensors, reordered_indices
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def initialize_optimizer_states(self):
for i, group in enumerate(self.bit16_groups):
single_grad_partition = torch.zeros(int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype,
device=self.device)
self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(
single_grad_partition) if self.cpu_offload_pin_memory else single_grad_partition
# Initialize the optimizer states with the flattened fp32 partition.
# State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
# which do lazy initialization of the state at the first call to step.
if isinstance(self.optimizer, torch.optim.Adagrad):
self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
else:
self.optimizer.step()
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None # class init
return
#########################################################################
#################### ZeRO Stage 1 - reduce gradients ####################
#########################################################################
def reduce_gradients(self, pipeline_parallel=False):
world_size = dist.get_world_size(self.zp_process_group)
my_rank = dist.get_rank(self.zp_process_group)
# with PP we must create ipg buffer, since backward is handled outside zero
if pipeline_parallel and self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size),
dtype=self.dtype,
device=get_accelerator().current_device_name())
self.ipg_buffer.append(buf_0)
self.ipg_index = 0
if not self.overlap_comm:
for i, group in enumerate(self.bit16_groups):
for param in group:
grad_reduc = self.get_gradient_for_reduction(param)
if grad_reduc is not None:
self.reduce_ready_partitions_and_remove_grads(param, i)
# reduce any pending grads in either hook/non-hook case
self.overlapping_partition_gradients_reduce_epilogue()
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
for i, param_group in enumerate(self.round_robin_bit16_groups):
total_partitions = dist.get_world_size(group=self.real_zp_process_group[i])
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.total_grads_in_partition[i][partition_id] = 0
self.initialize_gradient_partition(i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][partition_id] = self.get_first_param_index(
i, param_group, partition_id)
def independent_gradient_partition_epilogue(self):
self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
if self.overlap_comm:
get_accelerator().synchronize()
# It is safe to clear previously reduced grads of other partitions
self._clear_previous_reduced_grads()
if self.cpu_offload is False:
for i, _ in enumerate(self.bit16_groups):
if i not in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
return_tensor_list=True)
else:
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
return_tensor_list=True)
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
accumulated_grad.add_(new_avg_grad)
self._release_ipg_buffers()
# No need to keep the gradients anymore.
# All gradients required by the step
# are in self.averaged_gradients
self.zero_grad(set_to_none=True)
see_memory_usage(f"End ipg_epilogue")
# resets all partitions to not reduced
# sets remaining grads to the total number of grads in each partition
# sets is_grad_computed to False for all grads in each partition
def reset_partition_gradient_structures(self):
for i, _ in enumerate(self.bit16_groups):
total_partitions = dist.get_world_size(group=self.real_zp_process_group[i])
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if start_index <= current_index < end_index:
set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif current_index < start_index < (current_index + param_size):
assert (first_offset == 0
), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
def fill_grad_accum_attribute(self):
for group in self.bit16_groups:
for param in group:
if param.grad is not None:
if param.grad_accum is None:
param.grad_accum = param.grad.to(self.gradient_accumulation_dtype)
else:
param.grad_accum.add_(
param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape))
param.grad = None
def get_gradient_for_reduction(self, param):
if self.use_grad_accum_attribute:
return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None
else:
return param.grad
def get_param_gradient_attribute(self, param):
return param.grad_accum if self.use_grad_accum_attribute else param.grad
# Clear the tensor the reduction gradient attribute is pointing to
def clear_grad_attribute(self, param):
if self.use_grad_accum_attribute:
param.grad_accum = None
else:
param.grad = None
def create_reduce_and_remove_grad_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.bit16_groups):
for param in param_group:
if param.requires_grad:
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
grad_acc.register_hook(reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
wrapper(param, i)
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
see_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
return self.flatten(align_dense_tensors(tensor_list, alignment))
############### Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
grad_reduc = self.get_gradient_for_reduction(param)
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())
param_id = self.get_param_id(param)
assert not self.params_already_reduced[param_id], (
f"The parameter {param_id} has already been reduced. "
"Gradient computed twice for this partition. "
"Multiple gradient reduction is currently not supported")
if self.contiguous_gradients:
if param.numel() > self.reduce_bucket_size:
self.extra_large_param_to_reduce = param
else:
# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
new_grad_tensor.copy_(grad_reduc.view(-1))
grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)
self.elements_in_ipg_bucket += param.numel()
assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
self.grads_in_ipg_bucket.append(grad_reduc)
self.params_in_ipg_bucket.append((i, param, param_id))
# make sure the average tensor function knows how to average the gradients
if is_moe_param(param):
self.ipg_bucket_has_moe_params = True
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def print_rank_0(self, message):
if dist.get_rank() == 0:
logger.info(message)
def gradient_reduction_w_predivide(self, tensor):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
tensor_to_allreduce = tensor
if self.communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(self.communication_data_type)
if self.postscale_gradients:
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor /
(dp_world_size / float(self.sequence_parallel_size)))
else:
tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size))
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def allreduce_and_copy_with_multiple_ranks(self,
small_bucket,
log=None,
divide=True,
process_group=None,
bucket_ranks=None):
process_group = self.zp_process_group if process_group is None else process_group
allreduced = self.allreduce_bucket(small_bucket, log=log, divide=divide, process_group=process_group)
for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks):
if dist.get_rank(group=process_group) == bucket_rank:
buf.copy_(synced)
def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, divide=True, process_group=None):
small_bucket = []
small_bucket_ranks = []
numel = 0
allreduce_sizes = []
for i, bucket_elem in enumerate(bucket):
rank, tensor = bucket_elem
small_bucket.append(tensor)
small_bucket_ranks.append(rank)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy_with_multiple_ranks(small_bucket,
log=None,
divide=divide,
process_group=process_group,
bucket_ranks=small_bucket_ranks)
small_bucket = []
small_bucket_ranks = []
numel = 0
if len(small_bucket) > 0:
self.allreduce_and_copy_with_multiple_ranks(small_bucket,
log=None,
divide=divide,
process_group=process_group,
bucket_ranks=small_bucket_ranks)
def average_tensor(self, tensor):
if self.overlap_comm:
stream = self.reduction_stream
if not get_accelerator().is_synchronized_device():
stream.wait_stream(get_accelerator().current_stream())
else:
stream = get_accelerator().current_stream()
with get_accelerator().stream(stream):
if not self.reduce_scatter:
self.gradient_reduction_w_predivide(tensor)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
# Note: potential future optimization, record access pattern of parameters
# in backward pass and partition gradients w.r.t. access pattern so that our
# bucket is guaranteed to be contiguous w.r.t. ranks
rank_and_offsets = []
real_dp_process_group = []
curr_size = 0
prev_id, prev_process_group = -1, None
process_group = self.zp_process_group
# count = 0
for i, param, param_id in self.params_in_ipg_bucket:
process_group = self.zp_process_group
grad_reduc = self.get_gradient_for_reduction(param)
# Averages gradients at parameter level if ipg has a moe param
# Otherwise averaging is done at the entire buffer level at the end of the loop
# MoE params use different process groups
if self.ipg_bucket_has_moe_params:
process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
param) else self.zp_process_group
grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))
partition_ids = self.param_to_partition_ids[i][param_id]
assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}"
partition_size = self.partition_size[i]
# Get all partition ids + their offsets
partition_ids_w_offsets = []
for partition_id in partition_ids:
offset = self.grad_start_offset[i][partition_id][param_id]
partition_ids_w_offsets.append((partition_id, offset))
partition_ids_w_offsets.sort(key=lambda t: t[1])
# Calculate rank and offsets for grad slices
for idx in range(len(partition_ids_w_offsets)):
partition_id, offset = partition_ids_w_offsets[idx]
# if dist.get_rank() == 0 and count < 100:
# print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
# count += 1
# Calculate numel for grad slice depending on partition location
if idx == len(partition_ids_w_offsets) - 1:
# Last partition_id uses its own offset
numel = param.numel() - offset
else:
# Set numel to next partition's offset
numel = partition_ids_w_offsets[idx + 1][1] - offset
# Merge bucket ranges if they belong to the same rank
if partition_id == prev_id and process_group == prev_process_group:
prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)
else:
rank_and_offsets.append((partition_id, curr_size, numel))
real_dp_process_group.append(process_group)
curr_size += numel
prev_id, prev_process_group = partition_id, process_group
if not self.ipg_bucket_has_moe_params:
tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
buckets = {}
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else (
dst, real_dp_process_group[i])
if bucket_key not in buckets:
buckets[bucket_key] = []
if self.use_multi_rank_bucket_allreduce:
buckets[bucket_key].append((dst, grad_slice))
else:
buckets[bucket_key].append(grad_slice)
for bucket_key in buckets:
if self.use_multi_rank_bucket_allreduce:
self.allreduce_and_scatter(buckets[bucket_key],
numel_per_bucket=self.reduce_bucket_size,
divide=self.ipg_bucket_has_moe_params,
process_group=bucket_key)
else:
dst, process_group = bucket_key
self.allreduce_no_retain(buckets[bucket_key],
numel_per_bucket=self.reduce_bucket_size,
rank=dst,
divide=self.ipg_bucket_has_moe_params,
process_group=process_group)
##############################################################################
############################# CPU Offload Methods ############################
##############################################################################
def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
current_offset = 0
for i, tensor in enumerate(tensor_list):
param_id = self.get_param_id(tensor)
param_start_offset = 0
num_elements = tensor.numel()
# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
param_start_offset = first_offset
# we don't need all elements of the tensor
if num_elements > (partition_size - current_offset):
num_elements = partition_size - current_offset
self.grad_position[param_id] = [
int(group_id), int(param_start_offset),
int(current_offset), int(num_elements)
]
current_offset += num_elements
def update_overflow_tracker_for_param_grad(self, param):
grad_accum = self.get_param_gradient_attribute(param)
if grad_accum is not None and self._has_inf_or_nan(grad_accum.data):
self.local_overflow = True
def _get_offload_gradient_dict(self):
for param_group_index, _ in enumerate(self.optimizer.param_groups):
self.offload_gradient_dict[param_group_index] = []
for lp_param in self.params_in_partition[param_group_index]:
param_id = self.get_param_id(lp_param)
[_, _, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[param_group_index].grad.view(-1).narrow(
0, dest_offset, num_elements)
self.offload_gradient_dict[param_group_index].append(dest_tensor)
def async_accumulate_grad_in_cpu_via_gpu(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
# copy to a preexisting buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel())
# buffer for storing gradients for this parameter in CPU
def buffer_to_accumulate_to_in_cpu():
if not self.fp16_master_weights_and_gradients:
buffer = torch.zeros(param.numel(), dtype=param.dtype, device=self.device)
return get_accelerator().pin_memory(buffer) if self.cpu_offload_pin_memory else buffer
else:
return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
# accumulate gradients into param.grad_accum, or into the parts of it that belong to this partition
def accumulate_gradients():
grad_accum = self.get_param_gradient_attribute(param)
if not self.fp16_master_weights_and_gradients:
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)
grad_accum.data.view(-1).add_(dest_buffer)
else:
dest_buffer.narrow(0, source_offset,
num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
non_blocking=True)
grad_accum.data.view(-1).narrow(0, source_offset,
num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))
# move accumulated gradients back to CPU
def copy_gradients_to_cpu():
grad_accum = self.get_param_gradient_attribute(param)
if not self.fp16_master_weights_and_gradients:
self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True)
else:
self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow(
0, source_offset, num_elements),
non_blocking=True)
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu()
if self.micro_step_id > 0:
accumulate_gradients()
# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
copy_gradients_to_cpu()
def set_norm_for_param_grad(self, param):
param_id = self.get_param_id(param)
grad_accum = self.get_param_gradient_attribute(param)
accumulated_grad = self.accumulated_grads_in_cpu[
param_id] if self.gradient_accumulation_steps > 1 else grad_accum
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)
def set_norm_for_param_grad_in_gpu(self, param):
param_id = self.get_param_id(param)
grad_accum = self.get_param_gradient_attribute(param)
if grad_accum is None:
accumulated_grad = param.grad
else:
accumulated_grad = grad_accum
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
grad_accum = self.get_param_gradient_attribute(param)
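if grad_accum is None:
# Fall back to param.grad, mirroring set_norm_for_param_grad_in_gpu; calling .view() on None would raise
src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)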
else:
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
if not self.fp16_master_weights_and_gradients:
src_tensor = src_tensor.float()
dest_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None # offload only
def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
for p in params:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
# Some models have trainable parameters that are skipped in training,
# so their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run
# and they have no entry in norm_for_param_grads
if param_id in self.norm_for_param_grads:
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item() ** 2
else:
# Unused parameters in modules may be unexpected,
# so raise an explicit error message when they occur and offer an option to
# avoid the error
assert self.ignore_unused_parameters, """
This assert indicates that your module has parameters that
were not used in producing loss.
You can avoid this assert by
(1) enabling the ignore_unused_parameters option in zero_optimization config;
(2) making sure all trainable parameters and `forward` function
outputs participate in calculating loss.
"""
# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
############################################################################################
def copy_grads_in_partition(self, param):
if self.cpu_offload:
if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)
if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)
self.update_overflow_tracker_for_param_grad(param)
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
return
# print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
for group in self.params_in_partition:
for param_in_partition in group:
total_size += param_in_partition.numel()
see_memory_usage(f"before copying {total_size} gradients into partition")
self.grads_in_partition = torch.empty(int(total_size),
dtype=self.dtype,
device=get_accelerator().current_device_name())
see_memory_usage(f"after copying {total_size} gradients into partition")
grad_reduc = self.get_gradient_for_reduction(param)
# The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel())
new_grad_tensor.copy_(grad_reduc.view(-1))
grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)
# print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
if self.contiguous_gradients:
if self.extra_large_param_to_reduce is not None:
assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
_, _, param_id = self.params_in_ipg_bucket[0]
assert self.get_param_id(self.extra_large_param_to_reduce
) == param_id, "param in ipg bucket does not match extra-large param"
extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce)
self.average_tensor(extra_large_grad_reduc.view(-1))
self.extra_large_param_to_reduce = None
else:
self.average_tensor(self.ipg_buffer[self.ipg_index])
else:
self.buffered_reduce_fallback(None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
if self.overlap_comm:
stream = self.reduction_stream
elif self.cpu_offload:
# TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed.
# get_accelerator().synchronize()
# stream = self.copy_grad_stream
stream = get_accelerator().current_stream()
else:
stream = get_accelerator().current_stream()
with get_accelerator().stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
self.params_already_reduced[param_id] = True
if self.partition_gradients:
if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear grads of other partitions during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
self.previous_reduced_grads.append(param)
else:
self.clear_grad_attribute(param)
elif self.contiguous_gradients:
self.copy_grads_in_partition(param)
else: # zero stage 1 - partition only optimizer state
if self.contiguous_gradients and self.is_param_in_current_partition[param_id]:
self.copy_grads_in_partition(param)
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.ipg_bucket_has_moe_params = False
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
if self.partition_gradients or self.is_gradient_accumulation_boundary:
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None # dead code
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = self.flatten(tensors)
def print_func():
logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducible_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(total_elements - start,
self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(0, int(start), int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducible_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.zp_process_group
if dist.get_rank(group=group) == 0:
logger.info(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zeros_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None):
rank = None  # NOTE: overrides the caller-supplied rank, so the all_reduce branch below is always taken
tensor = self.flatten(bucket)
process_group = self.zp_process_group if process_group is None else process_group
tensor_to_allreduce = tensor
if pg_correctness_test or self.sequence_parallel_size > 1:
communication_data_type = torch.float32
else:
communication_data_type = self.communication_data_type
if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)
if divide:
tensor_to_allreduce.div_(
dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
tensor_to_allreduce = tensor_to_allreduce.contiguous()
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = dist.get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
def _clear_previous_reduced_grads(self):
if self.previous_reduced_grads is not None:
for param in self.previous_reduced_grads:
self.clear_grad_attribute(param)
self.previous_reduced_grads = None
# if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None):
process_group = self.zp_process_group if process_group is None else process_group
if self.overlap_comm:
get_accelerator().synchronize()
# It is safe to clear the previously reduced grads of other partitions
self._clear_previous_reduced_grads()
stream = self.reduction_stream
else:
stream = get_accelerator().current_stream()
with get_accelerator().stream(stream):
allreduced = self.allreduce_bucket(
small_bucket,
rank=rank,
log=log,
divide=divide,
process_group=process_group,
)
if rank is None or rank == dist.get_rank(group=self.zp_process_group):
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(
self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None,
divide=True,
process_group=None,
):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None, divide=divide, process_group=process_group)
small_bucket = []
numel = 0
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log, divide=divide, process_group=process_group)
# allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log)
#############################################################################
#############################################################################
#############################################################################
# views the tensor as multiple partitions and returns
# those partitions
def get_data_parallel_partitions(self, tensor, group_id):
partitions = []
dp = dist.get_world_size(group=self.real_zp_process_group[group_id])
# dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if start_index <= current_index < end_index:
params_in_partition.append(tensor)
elif current_index < start_index < (current_index + tensor_size):
params_in_partition.append(tensor)
assert (first_offset == 0
), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_to_none=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
# zero all pointers to grad tensors
for group in self.bit16_groups:
for p in group:
if set_to_none:
p.grad = None # epilogue and in step
p.grad_accum = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None or self.model_parallel_world_size == 1:
pass
else:
dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Computes the gradient norm of an iterable of gradients.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with
added functionality to handle model parallel parameters. Note that
the gradients are not modified.
Arguments:
gradients (Iterable[Tensor]): the gradients whose norm is computed
params (Iterable[Tensor]): the parameters matching ``gradients`` one to one
norm_type (float or int): type of the norm. ``inf`` selects the
infinity norm; otherwise per-gradient L2 norms are accumulated.
Returns:
Total norm of the gradients (viewed as a single vector), or -1 if
the norm is inf or NaN.
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# logger.info(f"Total Norm beginning {total_norm}")
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item() ** 2
# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
grad_accum = self.get_param_gradient_attribute(tensor)
if grad_accum is None:
grad_accum = torch.zeros_like(tensor, dtype=dtype)
tensor = grad_accum
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
# we don't need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
# we need a narrow view of the tensor based on the tensor offset and number of elements that
# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
# this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(torch.zeros(int(partition_size - current_size), dtype=dtype, device=device))
if return_tensor_list:
return flat_tensor_list
return self.flatten(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None # in step
p.grad_accum = None
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def set_lr(self, lr):
"""Set the learning rate."""
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def get_lr(self):
"""Return the current learning rate."""
return self.optimizer.param_groups[0]["lr"]
def override_loss_scale(self, loss_scale):
if loss_scale != self.external_loss_scale:
logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
self.custom_loss_scaler = True
self.external_loss_scale = loss_scale
def scaled_global_norm(self, norm_type=2):
assert norm_type == 2, "only L2 norm supported"
norm_groups = []
for i, group in enumerate(self.bit16_groups):
partition_id = dist.get_rank(group=self.real_zp_process_group[i])
if self.cpu_offload:
norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]))
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
else:
norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))
if self.has_moe_layers:
self._average_expert_grad_norms(norm_groups)
# note that the get_global_norm function only supports l2 norm
return get_global_norm(norm_list=norm_groups)
def get_bit16_param_group(self, group_no):
bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
partition_id = dist.get_rank(group=self.real_zp_process_group[group_no])
return [bit16_partitions[partition_id]]
def _optimizer_step(self, group_no):
original_param_groups = self.optimizer.param_groups
self.optimizer.param_groups = [original_param_groups[group_no]]
# Disabling this as the C++ side copy & synchronize is not working correctly
# from deepspeed.ops.adam import DeepSpeedCPUAdam
# if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
# self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
# else:
# self.optimizer.step()
self.optimizer.step()
self.optimizer.param_groups = original_param_groups
def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = -1
see_memory_usage("In step before checking overflow")
# First compute norm for all group so we know if there is overflow
if self.dtype == torch.float16:
self.check_overflow()
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
see_memory_usage('After overflow before clearing gradients')
self.zero_grad(set_to_none=True)
if self.cpu_offload:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
see_memory_usage('After overflow after clearing gradients')
for timer in OPTIMIZER_TIMERS:
self.timers(timer).start()
self.timers(timer).stop()
return
# Step 1:- Calculate gradient norm using bit-16 grads
see_memory_usage('Before norm calculation')
scaled_global_grad_norm = self.scaled_global_norm()
self._global_grad_norm = scaled_global_grad_norm / prev_scale
see_memory_usage('After norm before optimizer')
# Step 2:- run optimizer and upscaling simultaneously
for i, group in enumerate(self.bit16_groups):
self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
partition_id = dist.get_rank(group=self.real_zp_process_group[i])
if self.cpu_offload:
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
self.timers(OPTIMIZER_STEP_TIMER).start()
self._optimizer_step(i)
# Disabled, this is not currently working
# from deepspeed.ops.adam import DeepSpeedCPUAdam
# if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):
# bit16_partitions = self.parallel_partitioned_bit16_groups[i]
# fp32_partition = self.single_partition_of_fp32_groups[i]
# bit16_partitions[partition_id].data.copy_(fp32_partition.data)
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
self.timers(OPTIMIZER_STEP_TIMER).stop()
else:
# free gradients for all the parameters that are not updated by this process(ZeRO stage2)
self.free_grad_in_param_list(self.params_not_in_partition[i])
# create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(group=self.real_zp_process_group[i]) - 1:
single_grad_partition = self.flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements than partition size {} {} {} {}".format(
single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition(ZeRO stage2)
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
# Step 3:- run the optimizer if no offloading
self.timers(OPTIMIZER_STEP_TIMER).start()
self._optimizer_step(i)
# Step 4:- get rid of the fp32 gradients. Not needed anymore
self.single_partition_of_fp32_groups[i].grad = None
del single_grad_partition
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
self.timers(OPTIMIZER_STEP_TIMER).stop()
see_memory_usage('After optimizer before all-gather')
if self.cpu_offload:
self.reset_cpu_buffers()
self.timers(OPTIMIZER_ALLGATHER_TIMER).start()
# if dist.get_rank(group=self.dp_process_group) == 0:
# pdb.set_trace() # or use another debugging tool
# Gather the updated weights from everyone.
# Then all partitions of the model parameters are updated and ready for next round forward.
all_gather_into_tensor_dp_groups(groups_flat=self.bit16_groups_flat,
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
zp_process_group=self.real_zp_process_group)
self.timers(OPTIMIZER_ALLGATHER_TIMER).stop()
# TODO: we probably don't need this? just to be safe
for i in range(len(self.bit16_groups)):
self._update_model_bit16_weights(i)
self.timers.log(OPTIMIZER_TIMERS)
see_memory_usage('After zero_optimizer step')
return
@torch.no_grad()
def update_lp_params(self):
for i, (bit16_partitions, fp32_partition) in enumerate(
zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_zp_process_group[i])
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
# if i == 0:
# print_rank_0(f'{fp32_partition[:10]=}', force=True)
all_gather_into_tensor_dp_groups(groups_flat=self.bit16_groups_flat,
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
zp_process_group=self.real_zp_process_group)
def _average_expert_grad_norms(self, norm_groups):
for i, norm in enumerate(norm_groups):
if self.is_moe_param_group[i]:
scaled_norm = norm * 1.0 / float(dist.get_world_size(group=self.dp_process_group))
scaled_norm_tensor = torch.tensor(scaled_norm,
device=get_accelerator().device_name(),
dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=self.dp_process_group)
norm_groups[i] = scaled_norm_tensor.item()
def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
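# Illustrative sketch, not part of the original file: the scale arithmetic used by
# unscale_and_clip_grads above, pulled out into a standalone function so it can be
# checked by hand. The name `combined_scale` and the plain-float arguments are
# assumptions made for this example only.

```python
def combined_scale(total_norm, loss_scale, clip_grad):
    """Return the divisor applied to every gradient.

    total_norm is the norm of the *scaled* gradients (norm * loss_scale),
    matching the "norm is in fact norm*scale" comment in the method above.
    """
    scale = loss_scale
    if clip_grad > 0.0:
        # Ratio of the unscaled norm to the clipping threshold.
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            # Norm exceeds the threshold: divide by an extra factor of `clip`
            # so the unscaled norm ends up at roughly clip_grad.
            scale = clip * loss_scale
    return scale


# Unscaled norm 0.5 with threshold 1.0: only the loss scale is removed.
below_threshold = combined_scale(512.0, 1024.0, 1.0)
# Unscaled norm 4.0 with threshold 1.0: gradients are also shrunk by about 4x.
above_threshold = combined_scale(4096.0, 1024.0, 1.0)
```

# Folding the clip factor into the loss scale means each gradient tensor is
# multiplied once rather than unscaled and clipped in two separate passes.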
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.bit16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow])
'''This will capture overflow across all data parallel and expert parallel processes,
since expert parallel processes are a subset of data parallel processes.'''
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
else:
params = []
for group in self.bit16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
overflow_gpu = get_accelerator().ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if instance is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
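# Illustrative sketch, not part of the original file: the sum-based check in
# _has_inf_or_nan above, shown on plain Python floats instead of tensors. The
# function name `has_inf_or_nan` and the list input are assumptions for this
# example. A single inf or NaN anywhere in the input makes the sum non-finite,
# so one reduction is enough to flag the whole buffer.

```python
import math


def has_inf_or_nan(values):
    """Return True if the sum of `values` is inf, -inf, or NaN."""
    total = float(sum(values))
    # NaN is the only float that is not equal to itself.
    return math.isinf(total) or total != total


finite_case = has_inf_or_nan([1.0, -2.5, 3.0])
inf_case = has_inf_or_nan([1.0, float("inf")])
nan_case = has_inf_or_nan([float("nan"), 2.0])
# inf + (-inf) is NaN, so mixed infinities are still caught.
mixed_case = has_inf_or_nan([float("inf"), -float("inf")])
```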
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size),
dtype=self.dtype,
device=get_accelerator().current_device_name())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(int(self.reduce_bucket_size),
dtype=self.dtype,
device=get_accelerator().current_device_name())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
if self.custom_loss_scaler:
scaled_loss = self.external_loss_scale * loss
scaled_loss.backward()
else:
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
# Only for Stage 1, Mode 2
if self.use_grad_accum_attribute:
self.fill_grad_accum_attribute()
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
if self.custom_loss_scaler:
return self.external_loss_scale
else:
return self.loss_scaler.cur_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_length = group.numel() - self.groups_padding[i]
groups_without_padding.append(group[:lean_length])
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding(self.optimizer.state[p], self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict[LOSS_SCALER] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict[CLIP_GRAD] = self.clip_grad
if self.elastic_checkpoint:
state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state()
if "step" in self.optimizer.param_groups[0]:
# Assuming "step" is the only item that changes through training iterations
assert all(group["step"] == self.optimizer.param_groups[0]["step"]
for group in self.optimizer.param_groups), "All param groups must have the same step value"
state_dict[BASE_OPTIMIZER_STATE_STEP] = self.optimizer.param_groups[0]["step"]
else:
state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(self.single_partition_of_fp32_groups)
state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding
state_dict[
ZERO_STAGE] = ZeroStageEnum.gradients if self.partition_gradients else ZeroStageEnum.optimizer_states
state_dict[GROUP_PADDINGS] = self.groups_padding
state_dict[PARTITION_COUNT] = self.partition_count
state_dict[DS_VERSION] = version
state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings
return state_dict
# Restore base optimizer fp32 weights from elastic checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_elastic_fp32_weights(self, all_state_dict):
merged_single_partition_of_fp32_groups = []
for i in range(len(self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_zp_process_group[i])
merged_partitions = [sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict]
if self.is_moe_group(self.optimizer.param_groups[i]):
ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])
merged_partitions = [merged_partitions[i] for i in ranks]
flat_merged_partitions = self.flatten_dense_tensors_aligned(
merged_partitions,
self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_zp_process_group[i]))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i)
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights
def _restore_from_bit16_weights(self):
for group_id, (bit16_partitions, fp32_partition) in enumerate(
zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_zp_process_group[group_id])
fp32_partition.data.copy_(bit16_partitions[partition_id].data)
# Refresh the fp32 master params from the fp16 or bfloat16 copies.
def refresh_fp32_params(self):
self._restore_from_bit16_weights()
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id):
partition_id = dist.get_rank(group=self.real_zp_process_group[group_id])
alignment = dist.get_world_size(group=self.real_zp_process_group[group_id])
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id)
return dp_partitions[partition_id]
else:
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return all_partition_states[0]
def _restore_base_optimizer_state(self, base_optimizer_group_states):
if type(base_optimizer_group_states) == dict:
base_optimizer_group_states = base_optimizer_group_states['state']
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
dst_tensor = self.optimizer.state[p][key]
src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
self.optimizer.state[p][key].data.copy_(src_tensor.data)
else:
self.optimizer.state[p][key] = saved
def get_ep_ranks(self, rank=0, group_name=None):
from deepspeed.utils import groups
expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name)
world_size = groups._get_data_parallel_world_size()
rank = groups._get_expert_parallel_rank(group_name)
ranks = range(rank, world_size, expert_parallel_size_)
return list(ranks)
# Restore base optimizer state from elastic checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_elastic_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict]
if self.is_moe_group(self.optimizer.param_groups[i]):
ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])
all_partition_group_states = [all_partition_group_states[i] for i in ranks]
for key in all_partition_group_states[0].keys():
all_partition_states = [all_states[key] for all_states in all_partition_group_states]
partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)
base_optimizer_group_states.append(partition_states)
self._restore_base_optimizer_state(base_optimizer_group_states)
# Restore step
if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]:
assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
for sd in all_state_dict), "State dicts of all partitions must have the same step value"
loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
for param_group in self.optimizer.param_groups:
param_group['step'] = loaded_param_groups_step
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False,
checkpoint_folder=None,
load_serial=None):
if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self._load_hp_checkpoint_state(checkpoint_folder)
@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups
def _load_hp_checkpoint_state(self, checkpoint_dir):
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path)
self._load_global_state(optim_sd)
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()
for i, _ in enumerate(self.optimizer.param_groups):
for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
# print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
def _load_global_state(self, sd):
self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale)
self.overflow = sd.get('overflow', self.overflow)
self.clip_grad = sd.get(CLIP_GRAD, self.clip_grad)
ckpt_version = sd.get(DS_VERSION, False)
assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed"
ckpt_version = pkg_version.parse(ckpt_version)
# zero stage 1 mode
if not self.partition_gradients:
required_version = pkg_version.parse("0.3.17")
error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
"with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
"please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json."
assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}"
def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
dp_rank = dist.get_rank(group=self.zp_process_group)
current_rank_sd = state_dict_list[dp_rank]
self._load_global_state(current_rank_sd)
ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict)
# padding is always at the last rank/partition
# if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank
# scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus
# scenario-2 (expand): saving w. 2 gpus -> loading w. 4 gpus
# if load_optimizer_states:
# if new_dp_size:
# self.strip_padding()
# self.add_padding_w_new_dp_size()
# self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
if load_optimizer_states:
if ckpt_is_rigid:
# loading rigid ckpt into either rigid or elastic exec
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
else:
if self.elastic_checkpoint:
# loading elastic into elastic exec
self._restore_elastic_base_optimizer_state(state_dict_list)
else:
# loading an elastic checkpoint into rigid exec
self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 1 if changing DP degree and option 2 otherwise.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if load_from_fp32_weights:
# option 2 from above
if self.elastic_checkpoint and not ckpt_is_rigid:
self._restore_from_elastic_fp32_weights(state_dict_list)
else:
# For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient.
for current, saved in zip(self.single_partition_of_fp32_groups,
current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
src_tensor = _get_padded_tensor(saved, current.numel())
current.data.copy_(src_tensor.data)
else:
# option 1 from above
self._restore_from_bit16_weights()
if load_optimizer_states:
self._link_all_hp_params()
def _handle_overflow(cpu_sum, x, i):
import math
rank = dist.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")
def estimate_zero2_model_states_mem_needs(total_params,
num_gpus_per_node=1,
num_nodes=1,
cpu_offload=True,
additional_buffer_factor=1.5):
total_gpus = num_nodes * num_gpus_per_node
if cpu_offload:
gpu_mem = 2 * total_params
cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
else:
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
return int(cpu_mem), int(gpu_mem)
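# Illustrative sketch, not part of the original file: a worked example of the
# estimate above for a hypothetical 1B-parameter model on one node with 8 GPUs.
# The function body mirrors estimate_zero2_model_states_mem_needs; the name
# `zero2_mem_estimate` and the model size are assumptions for this example.
# Results are in bytes.

```python
def zero2_mem_estimate(total_params, num_gpus_per_node=1, num_nodes=1,
                       cpu_offload=True, additional_buffer_factor=1.5):
    """Return (cpu_mem, gpu_mem) in bytes, using the same formulas as above."""
    total_gpus = num_nodes * num_gpus_per_node
    if cpu_offload:
        # Only the 2-byte (bit16) parameters stay on the device.
        gpu_mem = 2 * total_params
        cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
    else:
        # 4 bytes/param kept on every device, plus 16 bytes/param sharded
        # across all devices.
        gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
        cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
    return int(cpu_mem), int(gpu_mem)


params = 1_000_000_000
cpu_off, gpu_off = zero2_mem_estimate(params, num_gpus_per_node=8, cpu_offload=True)
cpu_on, gpu_on = zero2_mem_estimate(params, num_gpus_per_node=8, cpu_offload=False)

for label, cpu, gpu in (("offload", cpu_off, gpu_off), ("no offload", cpu_on, gpu_on)):
    print(f"{label}: {cpu / 2 ** 30:.2f}GB per CPU, {gpu / 2 ** 30:.2f}GB per GPU")
```

# Under these assumptions, offloading cuts the per-GPU estimate from 6e9 to 2e9
# bytes, while the per-node CPU estimate is 4.8e10 bytes in both configurations.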
def model_to_params(model):
# shared params calculated only once
total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
return total_params
def estimate_zero2_model_states_mem_needs_all_live(model,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
for a given ``model`` and hardware setup.
If you have an actual model object, use this function and everything will be derived
automatically.
If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass
the ``total_params`` explicitly.
Args:
- ``model``: ``nn.Module`` object
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
- ``num_nodes``: how many nodes (defaults to 1)
- ``additional_buffer_factor``: estimation factor (defaults to 1.5)
"""
total_params = model_to_params(model)
estimate_zero2_model_states_mem_needs_all_cold(total_params=total_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
additional_buffer_factor=additional_buffer_factor)
def estimate_zero2_model_states_mem_needs_all_cold(total_params,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
for a given ``model`` and hardware setup.
If it's a hypothetical model, use this function, where you have to pass
``total_params`` explicitly.
If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything
will be derived automatically.
Args:
- ``total_params``: total model params
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
- ``num_nodes``: how many nodes (defaults to 1)
- ``additional_buffer_factor``: estimation factor (defaults to 1.5)
"""
def format_options(cpu_offload):
enabled = []
device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else "none"
enabled.append(f"offload_optimizer={device}")
return ", ".join(enabled)
nodes_str = "nodes" if num_nodes > 1 else "node"
gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
print("Estimated memory needed for params, optim states and gradients for a:\n"
f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
f"SW: Model with {int(total_params / 1e6)}M total params.")
print(" per CPU | per GPU | Options")
for cpu_offload in [True, False]:
cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(total_params=total_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
cpu_offload=cpu_offload,
additional_buffer_factor=additional_buffer_factor)
options_str = format_options(cpu_offload=cpu_offload)
print(f" {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}")
================================================
FILE: opensora/adaptor/utils.py
================================================
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/Megatron
Helper functions and classes from multiple sources.
"""
from collections.abc import Iterable
from deepspeed.moe.utils import is_moe_param
import os
import psutil
import gc
from math import sqrt
from packaging import version as pkg_version
import torch
from deepspeed import comm as dist
try:
from torch._six import inf
except ModuleNotFoundError:
from torch import inf
from deepspeed.utils import groups, logger
from deepspeed.runtime.constants import PIPE_REPLICATED
from numpy import prod
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.policy import transpose
from torch.nn import functional as F
torch_memory_reserved = get_accelerator().memory_reserved
torch_max_memory_reserved = get_accelerator().max_memory_reserved
class DummyOptim():
"""
Dummy optimizer presents model parameters as a param group, this is
primarily used to allow ZeRO-3 without an optimizer
"""
def __init__(self, params):
self.param_groups = []
self.param_groups.append({'params': params})
graph_cache = {}
def graph_process(replay_first_step, func, *args, **kwargs):
# `func` should only contain operations on the GPU
# Please ensure that the memory address of the data required by 'func' remains constant
if func.__name__ not in graph_cache:
cuda_stream = get_accelerator().Stream()
cuda_stream.wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(cuda_stream):
func(*args, **kwargs)
get_accelerator().current_stream().wait_stream(cuda_stream)
graph_cache[func.__name__] = get_accelerator().create_graph()
with get_accelerator().capture_to_graph(graph_cache[func.__name__]):
func(*args, **kwargs)
if replay_first_step:
get_accelerator().replay_graph(graph_cache[func.__name__])
else:
get_accelerator().replay_graph(graph_cache[func.__name__])
def noop_decorator(func):
return func
class noop_context(object):
def __init__(self):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def ensure_directory_exists(filename):
"""Create the directory path to ``filename`` if it does not already exist.
Args:
filename (str): A file path.
"""
dirname = os.path.dirname(filename)
os.makedirs(dirname, exist_ok=True)
def set_random_seed(seed):
"""Set the random seed for common PRNGs used during training: random, numpy, and torch.
Args:
seed (int): the seed to use
"""
import numpy
import random
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
def is_model_parallel_parameter(p) -> bool:
if hasattr(p, 'model_parallel') and p.model_parallel:
return True
if hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel:
return True
return False
def bwc_tensor_model_parallel_rank(mpu=None):
"""Backwards-compatible way of querying the tensor model parallel rank from
an ``mpu`` object.
*Tensor* model parallelism means that tensors are physically split across
processes. This contrasts with *pipeline* model parallelism, in which the
layers are partitioned but tensors left intact.
The API for tensor model parallelism has changed across versions and this
helper provides a best-effort implementation across versions of ``mpu``
objects. The preferred mechanism is
``mpu.get_tensor_model_parallel_rank()``.
This should "just work" with both Megatron-LM and DeepSpeed's pipeline
parallelism.
Args:
mpu (model parallel unit, optional): The tensor model parallel rank.
If ``mpu=None``, returns 0. Defaults to ``None``.
Returns:
int: the rank
"""
if mpu is None:
# No model parallelism is easy :)
return 0
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
# New Megatron and DeepSpeed convention (post pipeline-parallelism release)
return mpu.get_tensor_model_parallel_rank()
elif hasattr(mpu, 'get_slice_parallel_rank'):
# Some DeepSpeed + pipeline parallelism versions
return mpu.get_slice_parallel_rank()
else:
# Deprecated Megatron and DeepSpeed convention
return mpu.get_model_parallel_rank()
def copy_to_device(item, device, criterion_func):
"""
Return a copy of tensor on specified device.
Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
Parameters:
item: tensor to copy or (possibly nested) container of tensors to copy.
device: target device
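criterion_func: Function that restricts the copy operation to items meeting the criterion
Returns:
A copy of ``item`` with every matching tensor copied to ``device``; items that do not match are returned unchanged.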
"""
if criterion_func(item):
return item.to(device)
elif isinstance(item, list):
return [copy_to_device(v, device, criterion_func) for v in item]
elif isinstance(item, tuple):
return tuple([copy_to_device(v, device, criterion_func) for v in item])
elif isinstance(item, dict):
return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}
else:
return item
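# Illustrative sketch, not part of the original file: the recursive traversal
# pattern used by copy_to_device above, shown with plain Python values so it runs
# without an accelerator. The name `apply_nested`, the `transform` argument, and
# the sample `batch` are assumptions for this example; in copy_to_device the
# transform is `item.to(device)`.

```python
def apply_nested(item, transform, criterion_func):
    """Apply `transform` to every leaf that satisfies `criterion_func`.

    Lists, tuples, and dicts are rebuilt with the same structure; any other
    leaf that does not match is returned unchanged.
    """
    if criterion_func(item):
        return transform(item)
    elif isinstance(item, list):
        return [apply_nested(v, transform, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple(apply_nested(v, transform, criterion_func) for v in item)
    elif isinstance(item, dict):
        return {k: apply_nested(v, transform, criterion_func) for k, v in item.items()}
    return item


# Integers stand in for tensors; strings stand in for metadata left untouched.
batch = {"ids": [1, 2], "pair": (3, "label"), "name": "x"}
moved = apply_nested(batch, lambda v: v * 10, lambda v: isinstance(v, int))
```

# The criterion is tested before the container checks, so a container that itself
# satisfies the criterion is transformed as a whole rather than traversed.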
def move_to_device(item, device, criterion_func):
"""
Move tensor on to specified device by changing the storage.
Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
Parameters:
item: tensor to move or (possibly nested) container of tensors to move.
device: target device
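criterion_func: Function that restricts the move operation to items meeting the criterion
Returns:
``item`` with the storage of every matching tensor moved to ``device`` in place; containers are rebuilt around the same tensor objects.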
"""
if criterion_func(item):
device_copy = item.to(device)
item.data = device_copy.data
return item
elif isinstance(item, list):
return [move_to_device(v, device, criterion_func) for v in item]
elif isinstance(item, tuple):
return tuple([move_to_device(v, device, criterion_func) for v in item])
elif isinstance(item, dict):
return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
else:
return item
class CheckOverflow(object):
'''Checks for overflow in gradients across parallel processes'''
def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False, deepspeed=None):
self.mpu = mpu
self.params = [] if param_groups else None
self.zero_reduce_scatter = zero_reduce_scatter
self.deepspeed = deepspeed
self.has_moe_params = False
if param_groups:
for group in param_groups:
for param in group:
self.params.append(param)
if is_moe_param(param):
self.has_moe_params = True
def check_using_norm(self, norm_group, reduce_overflow=True):
# TODO: I don't think reduce_overflow is needed if mpu is None
overflow = -1 in norm_group
overflow_gpu = get_accelerator().FloatTensor([overflow])
if self.has_moe_params:
# In this case, we need to do an all_reduce across
# the expert_parallel_group, so that if there was
# an overflow due to expert weights, we detect it
# Only need to check groups.get_largest_expert_parallel_group()
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())
if self.mpu is not None:
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
elif reduce_overflow:
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
dist.barrier()
overflow = overflow_gpu[0].item()
return bool(overflow)
def check(self, param_groups=None):
params = []
has_moe_params = False
if param_groups is None:
params = self.params
has_moe_params = self.has_moe_params
else:
assert param_groups is not None, \
"self.params and param_groups both cannot be none"
for group in param_groups:
for param in group:
params.append(param)
if is_moe_param(param):
has_moe_params = True
return self.has_overflow(params, has_moe_params=has_moe_params)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for i, p in enumerate(params):
if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
return True
return False
def has_overflow(self, params, has_moe_params=None):
if has_moe_params is None:
has_moe_params = self.has_moe_params
overflow = self.has_overflow_serial(params)
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
overflow_gpu = get_accelerator().ByteTensor([overflow])
# deepspeed.comm.all_reduce(overflow_gpu,
# op=deepspeed.comm.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
if has_moe_params:
# All reduce this across expert_parallel_group, so that if an expert
# overflows, we detect it here
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())
if self.zero_reduce_scatter:
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())
elif self.mpu is not None:
if self.deepspeed is not None:
using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or (
not using_pipeline and self.deepspeed.enable_backward_allreduce is False):
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group())
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, i):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if instance is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def _handle_overflow(cpu_sum, x, i):
import math
rank = dist.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")
def get_global_norm(norm_list):
""" Compute total from a list of norms
"""
total_norm = 0.0
for norm in norm_list:
total_norm += norm**2.0
# logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}')
return sqrt(total_norm)
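A small worked example of the combination rule used by `get_global_norm`: per-group L2 norms are squared, summed, and square-rooted. This is a standalone sketch; the name `global_norm_sketch` is hypothetical and not part of the repository.

```python
from math import sqrt


def global_norm_sketch(norm_list):
    """Combine per-group L2 norms into a single norm over all groups."""
    total = 0.0
    for norm in norm_list:
        total += norm ** 2.0
    return sqrt(total)


# Two groups with norms 3 and 4 combine like the sides of a right triangle.
print(global_norm_sketch([3.0, 4.0]))  # 5.0
```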
def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
"""Clips gradient norm of an iterable of parameters.
This has been adapted from Nvidia megatron. We add norm averaging
to consider MoE params when calculating norm as they will result
in different norms across different ranks.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
# Take max across all GPUs.
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0
for p in parameters:
if mpu is not None:
if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p):
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item()**norm_type
else:
param_norm = p.grad.data.float().norm(norm_type)
total_norm += param_norm.item()**norm_type
# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
# Need to average total_norm across different GPUs due to the presence of moe params
pg = groups._get_data_parallel_group()
scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)])
dist.all_reduce(scaled_norm_tensor, group=pg)
total_norm = scaled_norm_tensor.item()
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
return total_norm
def get_grad_norm(parameters, norm_type=2, mpu=None):
"""Get grad norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with
added functionality to handle model parallel parameters. The gradients
are only read, not modified. Taken from Nvidia Megatron.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor whose gradient norm will be computed
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the gradients (viewed as a single vector).
-1 if the norm value is NaN or Inf.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
# Take max across all GPUs.
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
for p in parameters:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
# Filter to avoid over-counting replicated tensors from tensor
# model parallelism
if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
continue
param_norm = p.grad.data.float().norm(norm_type)
total_norm += param_norm.item()**norm_type
# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
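The `-1` sentinel returned above is what `CheckOverflow.check_using_norm` looks for with `-1 in norm_group`. The following torch-free sketch illustrates that convention; the helper names `sanitize_norm` and `overflow_from_norms` are hypothetical and exist only for this example.

```python
import math


def sanitize_norm(total_norm):
    """Map a non-finite norm to the -1 sentinel, as the norm helpers do."""
    if math.isinf(total_norm) or math.isnan(total_norm):
        return -1
    return total_norm


def overflow_from_norms(norm_group):
    """Report overflow if any group norm carries the -1 sentinel."""
    return -1 in norm_group


norms = [sanitize_norm(n) for n in (2.5, float("inf"))]
print(norms)                       # [2.5, -1]
print(overflow_from_norms(norms))  # True
```

Note that the sentinel only works if every norm passes through the sanitizing step first: a raw `nan` in the list would not compare equal to `-1`.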
def get_grad_zeros(parameters, mpu=None):
"""Compute the number of grads with zero values.
This is adapted from get_grad_norm
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor whose zero-valued gradient entries will be counted
Returns:
Total number of gradient entries equal to zero (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
total_zeros = 0.
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
for p in parameters:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
# Filter to avoid over-counting replicated tensors from tensor
# model parallelism
if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
continue
count_zeros = p.grad.numel() - torch.count_nonzero(p.grad)
total_zeros += count_zeros.item()
# Sum across all model parallel GPUs.
total_zeros_cuda = get_accelerator().FloatTensor([float(total_zeros)])
if mpu is not None:
dist.all_reduce(total_zeros_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
total_zeros = total_zeros_cuda[0].item()
return total_zeros
def get_weight_norm(parameters, norm_type=2, mpu=None):
"""Get norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with
added functionality to handle model parallel parameters. The parameters
are only read, not modified. Taken from Nvidia Megatron.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor whose weight norm will be computed
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
-1 if the norm value is NaN or Inf.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(p.data.abs().max() for p in parameters)
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
# Take max across all GPUs.
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
for p in parameters:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
# Filter to avoid over-counting replicated tensors from tensor
# model parallelism
if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
continue
param_norm = p.data.float().norm(norm_type)
total_norm += param_norm**norm_type
# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
def prefix_sum_inc(weights):
""" Compute an inclusive prefix sum.
Example:
>>> prefix_sum_inc([3,4,5])
[3, 7, 12]
"""
weights_ = [w for w in weights]
for x in range(1, len(weights_)):
weights_[x] += weights_[x - 1]
return weights_
def partition_uniform(num_items, num_parts):
import numpy
parts = [0] * (num_parts + 1)
# First check for the trivial edge case
if num_items <= num_parts:
for p in range(num_parts + 1):
parts[p] = min(p, num_items)
return parts
chunksize = num_items // num_parts
residual = num_items - (chunksize * num_parts)
parts = numpy.arange(0, (num_parts + 1) * chunksize, chunksize)
for i in range(residual):
parts[i + 1:] += 1
parts = parts.tolist()
return parts
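To make the return value of `partition_uniform` concrete: it is a list of `num_parts + 1` boundaries, where part `p` covers the half-open range `[parts[p], parts[p + 1])` and any remainder is spread over the first parts. The sketch below reproduces that logic without NumPy; `partition_uniform_sketch` is a hypothetical name used only here.

```python
def partition_uniform_sketch(num_items, num_parts):
    """Return num_parts + 1 boundaries splitting num_items as evenly as possible."""
    # Trivial case: at most one item per part, trailing parts are empty.
    if num_items <= num_parts:
        return [min(p, num_items) for p in range(num_parts + 1)]
    chunksize, residual = divmod(num_items, num_parts)
    parts = [p * chunksize for p in range(num_parts + 1)]
    # Give one extra item to each of the first `residual` parts by
    # shifting every later boundary right by one.
    for i in range(residual):
        for j in range(i + 1, num_parts + 1):
            parts[j] += 1
    return parts


print(partition_uniform_sketch(10, 3))  # [0, 4, 7, 10]
print(partition_uniform_sketch(2, 4))   # [0, 1, 2, 2, 2]
```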
def partition_balanced(weights, num_parts):
"""
Use dynamic programming to solve the linear partition problem.
see https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM
"""
import numpy as np
n = len(weights)
m = num_parts
if n <= m:
return partition_uniform(n, m)
dp_max = np.full((n + 1, m + 1), np.inf)
dp_min = np.full((n + 1, m + 1), np.inf)
dp_cost = np.full((n + 1, m + 1), np.inf)
position = np.zeros((n + 1, m + 1), dtype=int)
prefix_sum = np.zeros((n + 1))
prefix_sum[1:] = np.cumsum(weights)
dp_max[0, 0] = 0
dp_cost[0, 0] = 0
for i in range(1, n + 1):
for j in range(1, min(i, m) + 1):
for k in range(i):
max_sum = max(dp_max[k, j - 1], prefix_sum[i] - prefix_sum[k])
min_sum = min(dp_min[k, j - 1], prefix_sum[i] - prefix_sum[k])
cost = max_sum - min_sum
if dp_cost[i, j] >= cost:
dp_cost[i, j] = cost
dp_max[i, j] = max_sum
dp_min[i, j] = min_sum
position[i, j] = k
parts = [n]
for i in reversed(range(1, m + 1)):
parts.append(position[parts[-1], i])
parts.reverse()
return parts
class PartitionedTensor:
def __init__(self, tensor, group, partition_meta=None):
super().__init__()
self.group = group
self.num_parts = dist.get_world_size(group=self.group)
self.rank = dist.get_rank(group=self.group)
self.orig_size = list(tensor.size())
self.orig_device = tensor.device
self.local_data, self.partition = self._partition_tensor(tensor)
self.even_split = tensor.numel() % self.num_parts == 0
@classmethod
def from_meta(cls, meta, local_part, group, device=get_accelerator().device_name()):
assert meta.dtype == torch.long
dummy = torch.ones(dist.get_world_size(group=group))
part_obj = cls(tensor=dummy, group=group)
meta = meta.tolist()
# [N, list0, ..., listN-1]
part_obj.orig_size = meta[1:(1 + meta[0])]
meta = meta[1 + meta[0]:]
part_obj.orig_device = device
part_obj.local_data = local_part.detach()
part_obj.group = group
# Partition is encoded like the rowptr of a CSR matrix:
# [num_parts, rank, 0, part_1, ..., part_num_parts]
# TODO: support shuffle between different partition granularities
assert part_obj.num_parts == meta[0]
assert part_obj.rank == meta[1]
part_obj.partition = meta[2:] # length num_parts+1
return part_obj
def _partition_tensor(self, tensor):
partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
start = partition[self.rank]
length = partition[self.rank + 1] - start
tensor_part = tensor.detach().contiguous().view(-1).narrow(0, start=start, length=length).clone()
return tensor_part, partition
def full(self, device=None):
if device is None:
device = self.orig_device
# Allocate the full tensor as a flat buffer.
full_numel = prod(self.full_size())
flat_tensor = torch.zeros([full_numel], dtype=self.local_data.dtype, device=device)
if self.even_split:
# Collect the full tensor
dist.all_gather_into_tensor(flat_tensor, self.local_data, group=self.group)
else:
for part_id in range(self.num_parts):
part_size = self.partition[part_id + 1] - self.partition[part_id]
buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
if part_id == self.rank:
buf.copy_(self.local_data)
dist.broadcast(buf, part_id, self.group)
return flat_tensor.view(self.full_size()).clone().detach()
def to_meta(self):
"""Returns a torch.LongTensor that encodes partitioning information.
Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
communication.
Returns:
torch.LongTensor: a tensor encoding the meta-information for the partitioning
"""
meta = []
meta.append(len(self.orig_size))
meta += list(self.orig_size)
meta.append(self.num_parts)
meta.append(self.rank)
meta += self.partition
return torch.LongTensor(data=meta).to(self.orig_device)
def data(self):
return self.local_data
def local_size(self):
return self.local_data.size()
def full_size(self):
return self.orig_size
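The flat layout that `to_meta` produces and `from_meta` parses is `[ndim, *orig_size, num_parts, rank, *partition]`. The sketch below round-trips that layout with plain Python lists instead of a `torch.LongTensor`; `encode_meta` and `decode_meta` are hypothetical helpers written for illustration only.

```python
def encode_meta(orig_size, num_parts, rank, partition):
    """Flatten partition metadata in the same order as to_meta."""
    return [len(orig_size), *orig_size, num_parts, rank, *partition]


def decode_meta(meta):
    """Recover (orig_size, num_parts, rank, partition) from the flat list."""
    ndim = meta[0]
    orig_size = meta[1:1 + ndim]
    rest = meta[1 + ndim:]
    return orig_size, rest[0], rest[1], rest[2:]


# A 2x5 tensor split over 3 ranks, seen from rank 1.
meta = encode_meta([2, 5], num_parts=3, rank=1, partition=[0, 4, 7, 10])
print(meta)               # [2, 2, 5, 3, 1, 0, 4, 7, 10]
print(decode_meta(meta))  # ([2, 5], 3, 1, [0, 4, 7, 10])
```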
mem_alloced = 0
mem_cached = 0
def memory_status(msg, print_rank=-1, reset_max=False):
global mem_alloced, mem_cached
rank = dist.get_rank()
if print_rank != -1 and rank != print_rank:
return
get_accelerator().synchronize()
if reset_max:
get_accelerator().reset_max_memory_cached()
get_accelerator().reset_max_memory_allocated()
new_alloced = get_accelerator().memory_allocated()
new_cached = get_accelerator().memory_cached()
delta_alloced = new_alloced - mem_alloced
delta_cached = new_cached - mem_cached
mem_cached = new_cached
mem_alloced = new_alloced
max_alloced = get_accelerator().max_memory_allocated()
max_cached = get_accelerator().max_memory_cached()
# convert to GB for printing
new_alloced /= 1024**3
new_cached /= 1024**3
delta_alloced /= 1024**3
delta_cached /= 1024**3
max_alloced /= 1024**3
max_cached /= 1024**3
print(
f'RANK={rank} MEMSTATS', msg, f'device={get_accelerator().current_device_name()} '
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')
def get_ma_status():
if dist.is_initialized() and not dist.get_rank() == 0:
return 0
return get_accelerator().memory_allocated()
def empty_cache():
get_accelerator().empty_cache()
get_accelerator().reset_peak_memory_stats()
def see_memory_usage(message, force=False):
if not force:
return
if dist.is_initialized() and not dist.get_rank() == 0:
return
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
gc.collect()
# Print message except when distributed but not rank 0
logger.info(message)
logger.info(f"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \
Max_MA {round(get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \
CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \
Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024), 2)} GB ")
vm_stats = psutil.virtual_memory()
used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
logger.info(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
# get the peak memory to report correct data, so reset the counter for the next call
get_accelerator().reset_peak_memory_stats()
def call_to_str(base, *args, **kwargs):
"""Construct a string representation of a call.
Args:
base (str): name of the call
args (tuple, optional): args to ``base``
kwargs (dict, optional): kwargs supplied to ``base``
Returns:
str: A string representation of base(*args, **kwargs)
"""
name = f'{base}('
if args:
name += ', '.join(repr(arg) for arg in args)
if kwargs:
name += ', '
if kwargs:
name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
name += ')'
return name
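A short usage example for `call_to_str`. The function body is copied from above so the snippet runs on its own; the call names `forward` and `step` are arbitrary placeholders.

```python
def call_to_str(base, *args, **kwargs):
    """Construct a string representation of a call."""
    name = f'{base}('
    if args:
        name += ', '.join(repr(arg) for arg in args)
        if kwargs:
            name += ', '
    if kwargs:
        name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
    name += ')'
    return name


print(call_to_str('forward', 1, 2, mode='train'))  # forward(1, 2, mode='train')
print(call_to_str('step'))                         # step()
```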
def get_only_unique_item(items):
item_set = set(items)
if len(item_set) != 1:
raise RuntimeError(f"expected there to be only one unique element in {items}")
unique_item, = item_set
return unique_item
def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):
"""Clip the gradient of a list of parameters.
Args:
parameters: List of parameters whose .grad will be clipped.
global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.
mpu (optional): model parallelism unit. Defaults to None.
eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
Returns:
float: the global gradient norm
"""
if global_grad_norm is None:
global_grad_norm = get_grad_norm(parameters, mpu=mpu)
clip_coef = max_norm / (global_grad_norm + eps)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return global_grad_norm
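The clipping rule in `clip_gradients` can be checked by hand: the coefficient is `max_norm / (global_grad_norm + eps)`, and gradients are scaled only when it is below 1. The sketch below applies the same rule to a flat list of floats instead of parameter tensors; `clip_by_global_norm_sketch` is a hypothetical name and no model parallelism is involved.

```python
from math import sqrt


def clip_by_global_norm_sketch(grads, max_norm=1.0, eps=1e-6):
    """Scale grads so their L2 norm is at most (roughly) max_norm."""
    global_norm = sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (global_norm + eps)
    if clip_coef < 1:
        grads = [g * clip_coef for g in grads]
    return grads, global_norm


clipped, norm = clip_by_global_norm_sketch([3.0, 4.0], max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # roughly [0.6, 0.8]
```

Gradients whose norm is already below `max_norm` pass through unchanged, since the coefficient is then at least 1.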
def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False):
"""Get norm of an iterable of tensors.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Taken from Nvidia Megatron.
Arguments:
input_tensors (Iterable[Tensor]): an iterable of Tensors whose norm will be computed
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the tensors (viewed as a single vector).
"""
assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}'
assert all([torch.is_tensor(t) for t in input_tensors]), 'expected an iterable containing only tensors'
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(t.data.abs().max() for t in input_tensors)
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
if use_graph:
if 'norm_tensors_compute_buffer' not in graph_cache:
graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors]
compute_buffer = graph_cache['norm_tensors_compute_buffer']
def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
for i, t in enumerate(tensor_list):
_compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
if i != 0:
_compute_buffer[0].data.add_(_compute_buffer[i].data)
graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
total_norm = compute_buffer[0]
else:
total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False):
"""Clip list of tensors by global norm.
Args:
input_tensors: List of tensors to be clipped
global_norm (float, optional): Precomputed norm. Defaults to None.
mpu (optional): model parallelism unit. Defaults to None.
eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
Returns:
float: the global norm
"""
if global_norm is None:
global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph)
clip_coef = max_norm / (global_norm + eps)
if clip_coef < 1:
if use_graph:
def clip_tensors(_tensor_list, _clip_coef_tensor):
for t in _tensor_list:
t.detach().mul_(_clip_coef_tensor)
if 'clip_coef_tensor' not in graph_cache:
# Alloc memory
graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef,
dtype=torch.float32).to(get_accelerator().device_name())
clip_coef_tensor = graph_cache['clip_coef_tensor']
clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32))
graph_process(False, clip_tensors, input_tensors, clip_coef_tensor)
else:
for t in input_tensors:
t.detach().mul_(clip_coef)
return global_norm
def align_dense_tensors(tensor_list, alignment):
num_elements = sum(t.numel() for t in tensor_list)
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
else:
padded_tensor_list = tensor_list
return padded_tensor_list
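The padding arithmetic in `align_dense_tensors` is easy to verify in isolation: the total element count is rounded up to the next multiple of `alignment`. The sketch below works on element counts only, with no tensors; `padding_needed` is a hypothetical helper.

```python
def padding_needed(sizes, alignment):
    """Number of zero elements to append so sum(sizes) is a multiple of alignment."""
    remaining = sum(sizes) % alignment
    return alignment - remaining if remaining else 0


print(padding_needed([3, 6], alignment=4))  # 3  (9 elements padded to 12)
print(padding_needed([3, 5], alignment=4))  # 0  (8 is already aligned)
```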
def all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, dp_process_group=None):
for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)):
partition_id = dist.get_rank(group=zp_process_group[group_id])
dp_world_size = dist.get_world_size(group=dp_process_group)
if dp_world_size == 1:
# No groups share optimizer states.
# Pipeline parallelism with bf16 calls this by default even when dp size = 1.
continue
# print("call contiguous for all_gather_into_tensor_dp_groups")
dist.all_gather_into_tensor(group_flat, partitioned_params[partition_id].contiguous(), dp_process_group)
def all_gather_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, start_alignment_factor,
allgather_bucket_size, dp_process_group=None):
# if dist.has_all_gather_into_tensor():
return all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, dp_process_group)
# for group_id, partitioned_params in enumerate(partitioned_param_groups):
# # Sequential AllGather Best of both worlds
# partition_id = dist.get_rank(group=dp_process_group[group_id])
# dp_world_size = dist.get_world_size(group=dp_process_group[group_id])
#
# if dp_world_size == 1:
# # no groups share optimizer states
# # pipeline parallel with bf16 will default call this even if dp size = 1.
# continue
# num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)
#
# shard_size = partitioned_params[partition_id].numel() // num_shards
#
# # Enforce nccl/rccl alignment of start location of each shard
# shard_size = shard_size - (shard_size % start_alignment_factor)
#
# num_elements = shard_size
#
# assert shard_size * num_shards <= partitioned_params[partition_id].numel()
#
# for shard_id in range(num_shards):
#
# if shard_id == (num_shards - 1):
# num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size
#
# shard_list = []
# for dp_id in range(dp_world_size):
# curr_shard = partitioned_params[dp_id].narrow(0, shard_id * shard_size, num_elements).detach()
# shard_list.append(curr_shard)
# dist.all_gather(shard_list, shard_list[partition_id].contiguous(), dp_process_group[group_id])
class TLinear(torch.nn.Linear):
def __init__(self, orig_layer, name=""):
self.name = name
super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None))
self.weight.data = transpose(orig_layer.weight.data)
self.bias = orig_layer.bias
self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd
def _fwd(self, input):
return F.linear(input, self.weight)
def _fwd_bias_add(self, input):
return F.linear(input, self.weight, bias=self.bias)
def forward(self, input):
return self._fwd_func(input)
def get_inactive_params(param_list):
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
return [param for param in param_list if (hasattr(param, 'ds_id') and \
param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
def required_torch_version(min_version=None, max_version=None):
assert min_version or max_version, "Must provide a min_version or max_version argument"
torch_version = pkg_version.parse(torch.__version__)
if min_version and pkg_version.parse(str(min_version)) > torch_version:
return False
if max_version and pkg_version.parse(str(max_version)) < torch_version:
return False
return True
================================================
FILE: opensora/adaptor/zp_manager.py
================================================
import torch
import os
import torch.distributed as dist
class ZPManager(object):
def __init__(self, zp_size=8):
self.rank = int(os.getenv('RANK', '0'))
self.world_size = int(os.getenv("WORLD_SIZE", '1'))
self.zp_size = zp_size
self.zp_group = None
self.zp_rank = None
self.is_initialized = False
def init_group(self):
if self.is_initialized:
return
self.is_initialized = True
# Build one ZP process group per consecutive block of zp_size ranks.
num_zp_groups: int = self.world_size // self.zp_size
for i in range(num_zp_groups):
ranks = range(i * self.zp_size, (i + 1) * self.zp_size)
group = dist.new_group(ranks)
if self.rank in ranks:
self.zp_group = group
self.zp_rank = self.rank % self.zp_size
zp_manager = ZPManager()
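The group layout that `ZPManager.init_group` builds can be previewed without `torch.distributed`. This sketch only reproduces the rank arithmetic (consecutive blocks of `zp_size` ranks, local rank `rank % zp_size`); `zp_layout` and `zp_local_rank` are hypothetical helpers, not part of the repository.

```python
def zp_layout(world_size, zp_size=8):
    """List the global ranks in each group, using the same floor division."""
    num_zp_groups = world_size // zp_size
    return [list(range(i * zp_size, (i + 1) * zp_size)) for i in range(num_zp_groups)]


def zp_local_rank(rank, zp_size=8):
    """Position of a global rank inside its group."""
    return rank % zp_size


print(zp_layout(16, zp_size=8))  # two groups: ranks 0-7 and ranks 8-15
print(zp_local_rank(11))         # 3
print(zp_layout(1))              # [] (no group when world_size < zp_size)
```

Because the group count uses floor division, a world size that is not a multiple of `zp_size` leaves the trailing ranks without a group, so it seems safest to pick a `zp_size` that divides the world size.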
================================================
FILE: opensora/dataset/__init__.py
================================================
from torchvision.transforms import Compose
from transformers import AutoTokenizer, AutoImageProcessor
from torchvision import transforms
from torchvision.transforms import Lambda
try:
import torch_npu
except Exception:
torch_npu = None
from opensora.dataset.t2v_datasets import T2V_dataset
from opensora.dataset.inpaint_dataset import Inpaint_dataset
from opensora.models.causalvideovae import ae_norm, ae_denorm
from opensora.dataset.transform import ToTensorVideo, TemporalRandomCrop, MaxHWResizeVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo, NormalizeVideo, ToTensorAfterResize
def getdataset(args):
temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x
norm_fun = ae_norm[args.ae]
if args.force_resolution:
resize = [CenterCropResizeVideo((args.max_height, args.max_width)), ]
else:
resize = [
MaxHWResizeVideo(args.max_hxw),
SpatialStrideCropVideo(stride=args.hw_stride),
]
tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir)
tokenizer_2 = None
if args.text_encoder_name_2 is not None:
tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir)
if args.dataset == 't2v':
transform = transforms.Compose([
ToTensorVideo(),
*resize,
norm_fun
]) # also works for images, since an image is a video with a single frame
return T2V_dataset(
args, transform=transform, temporal_sample=temporal_sample,
tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2
)
elif args.dataset == 'i2v' or args.dataset == 'inpaint':
resize_transform = Compose(resize)
transform = Compose([
ToTensorAfterResize(),
norm_fun,
])
return Inpaint_dataset(
args, resize_transform=resize_transform, transform=transform,
temporal_sample=temporal_sample, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2
)
raise NotImplementedError(args.dataset)
if __name__ == "__main__":
'''
python opensora/dataset/__init__.py
'''
from accelerate import Accelerator
from opensora.dataset.t2v_datasets import dataset_prog
from opensora.utils.dataset_utils import LengthGroupedSampler, Collate
from torch.utils.data import DataLoader
import random
from torch import distributed as dist
from tqdm import tqdm
args = type('args', (),
{
'ae': 'WFVAEModel_D32_4x8x8',
'dataset': 't2v',
'model_max_length': 512,
'max_height': 640,
'max_width': 640,
'hw_stride': 16,
'num_frames': 93,
'compress_kv_factor': 1,
'interpolation_scale_t': 1,
'interpolation_scale_h': 1,
'interpolation_scale_w': 1,
'cache_dir': '../cache_dir',
'data': '/home/image_data/gyy/mmdit/Open-Sora-Plan/scripts/train_data/current_hq_on_npu.txt',
'train_fps': 18,
'drop_short_ratio': 0.0,
'speed_factor': 1.0,
'cfg': 0.1,
'text_encoder_name_1': 'google/mt5-xxl',
'text_encoder_name_2': None,
'dataloader_num_workers': 8,
'force_resolution': False,
'use_decord': True,
'group_data': True,
'train_batch_size': 1,
'gradient_accumulation_steps': 1,
'ae_stride': 8,
'ae_stride_t': 4,
'patch_size': 2,
'patch_size_t': 1,
'total_batch_size': 256,
'sp_size': 1,
'max_hxw': 384*384,
'min_hxw': 384*288,
# 'max_hxw': 236544,
# 'min_hxw': 102400,
}
)
# accelerator = Accelerator()
dataset = getdataset(args)
# data = next(iter(dataset))
# import ipdb;ipdb.set_trace()
# print()
sampler = LengthGroupedSampler(
args.train_batch_size,
world_size=1,
gradient_accumulation_size=args.gradient_accumulation_steps,
initial_global_step=0,
lengths=dataset.lengths,
group_data=args.group_data,
)
train_dataloader = DataLoader(
dataset,
shuffle=False,
# pin_memory=True,
collate_fn=Collate(args),
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
sampler=sampler,
drop_last=False,
prefetch_factor=4
)
import ipdb;ipdb.set_trace()
import imageio
import numpy as np
from einops import rearrange
while True:
for idx, i in enumerate(tqdm(train_dataloader)):
pixel_values = i[0][0]
pixel_values_ = (pixel_values+1)/2
pixel_values_ = rearrange(pixel_values_, 'c t h w -> t h w c') * 255.0
pixel_values_ = pixel_values_.numpy().astype(np.uint8)
imageio.mimwrite(f'output{idx}.mp4', pixel_values_, fps=args.train_fps)
dist.barrier()
pass
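The preview loop above maps pixel values from [-1, 1] to [0, 255] and casts straight to `uint8`, which truncates fractions and wraps any out-of-range value. The following is a minimal sketch of a more defensive conversion, assuming a NumPy array in (C, T, H, W) layout; the helper name `to_uint8_frames` is hypothetical and not part of the repository.

```python
import numpy as np


def to_uint8_frames(video_cthw):
    """Convert a float array in [-1, 1], layout (C, T, H, W), to uint8 (T, H, W, C)."""
    frames = np.asarray(video_cthw, dtype=np.float32)
    # (C, T, H, W) -> (T, H, W, C), the layout video writers usually expect.
    frames = np.transpose(frames, (1, 2, 3, 0))
    # Map [-1, 1] to [0, 255].
    frames = (frames + 1.0) / 2.0 * 255.0
    # Round, then clip so out-of-range values saturate instead of wrapping.
    frames = np.clip(np.rint(frames), 0, 255)
    return frames.astype(np.uint8)


if __name__ == "__main__":
    dummy = np.random.uniform(-1.0, 1.0, size=(3, 8, 16, 16))
    out = to_uint8_frames(dummy)
    print(out.shape, out.dtype)
```

The result could then be passed to a writer such as `imageio.mimwrite`, as the loop above does.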
================================================
FILE: opensora/dataset/inpaint_dataset.py
================================================
import time
import traceback
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:  # torch_npu is optional; fall back when it cannot be imported
torch_npu = None
npu_config = None
import glob
import json
import pickle
import os, io, csv, math, random
import numpy as np
import torchvision
from einops import rearrange
from os.path import join as opj
from collections import Counter
import cv2
import pandas as pd
import time
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, Dataset, get_worker_info
from tqdm import tqdm
from PIL import Image
from accelerate.logging import get_logger
import gc
import decord
from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing
from opensora.dataset.transform import get_params, maxhwresize, add_masking_notice, calculate_statistics, \
add_aesthetic_notice_image, add_aesthetic_notice_video
from opensora.utils.mask_utils import MaskProcessor, STR_TO_TYPE
from opensora.dataset.t2v_datasets import T2V_dataset, DataSetProg
logger = get_logger(__name__)
dataset_prog = DataSetProg()
def type_ratio_normalize(mask_type_ratio_dict):
for k, v in mask_type_ratio_dict.items():
assert v >= 0, f"mask_type_ratio_dict[{k}] should be non-negative, but got {v}"
total = sum(mask_type_ratio_dict.values())
length = len(mask_type_ratio_dict)
if total == 0:
return {k: 1.0 / length for k in mask_type_ratio_dict.keys()}
return {k: v / total for k, v in mask_type_ratio_dict.items()}
class Inpaint_dataset(T2V_dataset):
def __init__(self, args, resize_transform, transform, temporal_sample, tokenizer_1, tokenizer_2):
super().__init__(
args=args,
transform=transform,
temporal_sample=temporal_sample,
tokenizer_1=tokenizer_1,
tokenizer_2=tokenizer_2
)
self.resize_transform = resize_transform
if self.num_frames != 1:
self.mask_type_ratio_dict_video = args.mask_type_ratio_dict_video if args.mask_type_ratio_dict_video is not None else {'random_temporal': 1.0}
self.mask_type_ratio_dict_video = {STR_TO_TYPE[k]: v for k, v in self.mask_type_ratio_dict_video.items()}
self.mask_type_ratio_dict_video = type_ratio_normalize(self.mask_type_ratio_dict_video)
self.mask_type_ratio_dict_image = args.mask_type_ratio_dict_image if args.mask_type_ratio_dict_image is not None else {'random_spatial': 1.0}
self.mask_type_ratio_dict_image = {STR_TO_TYPE[k]: v for k, v in self.mask_type_ratio_dict_image.items()}
self.mask_type_ratio_dict_image = type_ratio_normalize(self.mask_type_ratio_dict_image)
print(f"mask_type_ratio_dict_video: {self.mask_type_ratio_dict_video}")
print(f"mask_type_ratio_dict_image: {self.mask_type_ratio_dict_image}")
self.mask_processor = MaskProcessor(
max_height=args.max_height,
max_width=args.max_width,
min_clear_ratio=args.min_clear_ratio,
max_clear_ratio=args.max_clear_ratio,
)
self.default_text_ratio = args.default_text_ratio
def __getitem__(self, idx):
try:
# future = self.executor.submit(self.get_data, idx)
# data = future.result(timeout=self.timeout)
# return data
return self.get_data(idx)
except Exception as e:
# if len(str(e)) < 2:
# e = f"TimeoutError, {self.timeout}s timeout occur with {dataset_prog.cap_list[idx]['path']}"
print(f'Error with {e}')
index_cand = self.shape_idx_dict[self.sample_size[idx]] # pick same shape
return self.__getitem__(random.choice(index_cand))
# return self.__getitem__(idx)
def get_data(self, idx):
path = dataset_prog.cap_list[idx]['path']
if not os.path.exists(path):
print(f"file {path} does not exist, randomly choosing a new one with the same shape!")
index_cand = self.shape_idx_dict[self.sample_size[idx]]
return self.__getitem__(random.choice(index_cand))
if path.endswith('.mp4'):
return self.get_video(idx)
else:
return self.get_image(idx)
def drop(self, text, is_video=True):
rand_num = random.random()
rand_num_text = random.random()
if rand_num < self.cfg:
if rand_num_text < self.default_text_ratio:
if not is_video:
text = "The image showcases a scene with coherent and clear visuals."
else:
text = "The video showcases a scene with coherent and clear visuals."
else:
text = ''
return dict(text=text)
def get_video(self, idx):
# npu_config.print_msg(f"current idx is {idx}")
# video = random.choice([random_video_noise(65, 3, 336, 448), random_video_noise(65, 3, 1024, 1024), random_video_noise(65, 3, 360, 480)])
# # print('random shape', video.shape)
# input_ids = torch.ones(1, 120).to(torch.long).squeeze(0)
# cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0)
# logger.info(f'Now we use t2v dataset {idx}')
video_data = dataset_prog.cap_list[idx]
video_path = video_data['path']
# assert os.path.exists(video_path), f"file {video_path} do not exist!"
sample_h = video_data['resolution']['sample_height']
sample_w = video_data['resolution']['sample_width']
if self.video_reader == 'decord':
video = self.decord_read(video_data)
elif self.video_reader == 'opencv':
video = self.opencv_read(video_data)
else:
raise NotImplementedError(f'Found {self.video_reader}, but only decord or opencv is supported')
# import ipdb;ipdb.set_trace()
video = self.resize_transform(video) # T C H W -> T C H W
assert video.shape[2] == sample_h and video.shape[3] == sample_w, f'sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape}), video_path ({video_path})'
inpaint_cond_data = self.mask_processor(video, mask_type_ratio_dict=self.mask_type_ratio_dict_video)
mask, masked_video = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values']
video = self.transform(video) # T C H W -> T C H W
masked_video = self.transform(masked_video) # T C H W -> T C H W
video = torch.cat([video, masked_video, mask], dim=1) # T 2C+1 H W
video = video.transpose(0, 1) # T C H W -> C T H W
text = video_data['cap']
if not isinstance(text, list):
text = [text]
text = [random.choice(text)]
if video_data.get('aesthetic', None) is not None or video_data.get('aes', None) is not None:
aes = video_data.get('aesthetic', None) or video_data.get('aes', None)
text = [add_aesthetic_notice_video(text[0], aes)]
text = self.drop(text, is_video=True)['text']
text_tokens_and_mask_1 = self.tokenizer_1(
text,
max_length=self.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_1 = text_tokens_and_mask_1['input_ids']
cond_mask_1 = text_tokens_and_mask_1['attention_mask']
input_ids_2, cond_mask_2 = None, None
if self.tokenizer_2 is not None:
text_tokens_and_mask_2 = self.tokenizer_2(
text,
max_length=self.tokenizer_2.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_2 = text_tokens_and_mask_2['input_ids']
cond_mask_2 = text_tokens_and_mask_2['attention_mask']
return dict(
pixel_values=video, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1,
input_ids_2=input_ids_2, cond_mask_2=cond_mask_2,
)
def get_image(self, idx):
image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...]
sample_h = image_data['resolution']['sample_height']
sample_w = image_data['resolution']['sample_width']
image = Image.open(image_data['path']).convert('RGB') # [h, w, c]
image = torch.from_numpy(np.array(image)) # [h, w, c]
image = rearrange(image, 'h w c -> c h w').unsqueeze(0) # [1 c h w]
image = self.resize_transform(image) # [1 c h w]
assert image.shape[2] == sample_h and image.shape[3] == sample_w, f"image_data: {image_data}, but found image {image.shape}"
inpaint_cond_data = self.mask_processor(image, mask_type_ratio_dict=self.mask_type_ratio_dict_image)
mask, masked_image = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values']
image = self.transform(image)
masked_image = self.transform(masked_image)
image = torch.cat([image, masked_image, mask], dim=1) # [1 2C+1 H W]
# image = [torch.rand(1, 3, 480, 640) for i in image_data]
image = image.transpose(0, 1) # [1 C H W] -> [C 1 H W]
caps = image_data['cap'] if isinstance(image_data['cap'], list) else [image_data['cap']]
caps = [random.choice(caps)]
# caps = [caps[0]]
if '/sam/' in image_data['path']:
caps = [add_masking_notice(caps[0])]
if image_data.get('aesthetic', None) is not None or image_data.get('aes', None) is not None:
aes = image_data.get('aesthetic', None) or image_data.get('aes', None)
caps = [add_aesthetic_notice_image(caps[0], aes)]
text = text_preprocessing(caps, support_Chinese=self.support_Chinese)
text = self.drop(text, is_video=False)['text']
text_tokens_and_mask_1 = self.tokenizer_1(
text,
max_length=self.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_1 = text_tokens_and_mask_1['input_ids'] # 1, l
cond_mask_1 = text_tokens_and_mask_1['attention_mask'] # 1, l
input_ids_2, cond_mask_2 = None, None
if self.tokenizer_2 is not None:
text_tokens_and_mask_2 = self.tokenizer_2(
text,
max_length=self.tokenizer_2.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_2 = text_tokens_and_mask_2['input_ids'] # 1, l
cond_mask_2 = text_tokens_and_mask_2['attention_mask'] # 1, l
return dict(
pixel_values=image, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1, motion_score=None,
input_ids_2=input_ids_2, cond_mask_2=cond_mask_2
)
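The `drop` method above implements caption dropout for classifier-free guidance: with probability `cfg` the caption is discarded, and a dropped caption is replaced by a generic sentence with probability `default_text_ratio`, otherwise by an empty string. The following is a standalone sketch of that logic; the function name `drop_caption`, the `rng` parameter, and the module-level constants are hypothetical and exist only for illustration.

```python
import random

DEFAULT_VIDEO_TEXT = "The video showcases a scene with coherent and clear visuals."
DEFAULT_IMAGE_TEXT = "The image showcases a scene with coherent and clear visuals."


def drop_caption(text, cfg, default_text_ratio, is_video=True, rng=random):
    """Return the caption, a generic default sentence, or '' (caption dropout)."""
    # Draw both numbers up front, mirroring the order used in the method above.
    drop_roll = rng.random()
    default_roll = rng.random()
    if drop_roll < cfg:
        if default_roll < default_text_ratio:
            return DEFAULT_VIDEO_TEXT if is_video else DEFAULT_IMAGE_TEXT
        return ""
    return text


if __name__ == "__main__":
    rng = random.Random(0)  # seeded only to make the demo repeatable
    results = [drop_caption("a cat", 0.1, 0.5, rng=rng) for _ in range(10000)]
    kept = sum(r == "a cat" for r in results)
    print(f"kept {kept} of {len(results)} captions")
```

With `cfg=0.1`, roughly 90% of captions should survive; the remainder is split between the default sentence and the empty string according to `default_text_ratio`.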
================================================
FILE: opensora/dataset/t2v_datasets.py
================================================
import time
import traceback
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:  # torch_npu is optional; fall back when it cannot be imported
torch_npu = None
npu_config = None
import glob
import json
import pickle
import os, io, csv, math, random
import numpy as np
import torchvision
from einops import rearrange
from os.path import join as opj
from collections import Counter
import cv2
import pandas as pd
import time
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, Dataset, get_worker_info
from tqdm import tqdm
from PIL import Image
from accelerate.logging import get_logger
import gc
from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing
from opensora.dataset.transform import get_params, maxhwresize, add_masking_notice, calculate_statistics, \
add_aesthetic_notice_image, add_aesthetic_notice_video
import decord
from concurrent.futures import ThreadPoolExecutor, TimeoutError
logger = get_logger(__name__)
def filter_json_by_existed_files(directory, data, postfix=".mp4"):
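# Build the search pattern to match files with the given suffix
pattern = os.path.join(directory, '**', f'*{postfix}')
mp4_files = glob.glob(pattern, recursive=True) # use glob to find all matching files
# Build a set from the absolute paths of the files
mp4_files_set = set(os.path.abspath(path) for path in mp4_files)
# Filter the data entries, keeping only those whose path is in the set of mp4 files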
filtered_items = [item for item in data if item['path'] in mp4_files_set]
return filtered_items
def random_video_noise(t, c, h, w):
vid = torch.rand(t, c, h, w) * 255.0
vid = vid.to(torch.uint8)
return vid
class SingletonMeta(type):
"""
A metaclass for creating singleton classes.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]
class DataSetProg(metaclass=SingletonMeta):
def __init__(self):
self.cap_list = []
self.elements = []
self.num_workers = 1
self.n_elements = 0
self.worker_elements = dict()
self.n_used_elements = dict()
def set_cap_list(self, num_workers, cap_list, n_elements):
self.num_workers = num_workers
self.cap_list = cap_list
self.n_elements = n_elements
self.elements = list(range(n_elements))
print(f"n_elements: {len(self.elements)}", flush=True)
# if torch_npu is not None:
# random.shuffle(self.elements)
# for i in range(self.num_workers):
# self.n_used_elements[i] = 0
# per_worker = int(math.ceil(len(self.elements) / float(self.num_workers)))
# start = i * per_worker
# end = min(start + per_worker, len(self.elements))
# self.worker_elements[i] = self.elements[start: end]
def get_item(self, work_info):
if work_info is None:
worker_id = 0
else:
worker_id = work_info.id
idx = self.worker_elements[worker_id][self.n_used_elements[worker_id] % len(self.worker_elements[worker_id])]
self.n_used_elements[worker_id] += 1
return idx
dataset_prog = DataSetProg()
def find_closest_y(x, vae_stride_t=4, model_ds_t=1):
min_num_frames = 29
if x < min_num_frames:
return -1
for y in range(x, min_num_frames - 1, -1):
if (y - 1) % vae_stride_t == 0 and ((y - 1) // vae_stride_t + 1) % model_ds_t == 0:
# 4, 8: y in [29, 61, 93, 125, 157, 189, 221, 253, 285, 317, 349, 381, 413, 445, 477, 509, ...]
# 4, 4: y in [29, 45, 61, 77, 93, 109, 125, 141, 157, 173, 189, 205, 221, 237, 253, 269, 285, 301, 317, 333, 349, 365, 381, 397, 413, 429, 445, 461, 477, 493, 509, ...]
# 8, 1: y in [33, 41, 49, 57, 65, 73, 81, 89, 97, 105]
# 8, 2: y in [41, 57, 73, 89, 105]
# 8, 4: y in [57, 89]
# 8, 8: y in [57]
return y
return -1
def filter_resolution(h, w, max_h_div_w_ratio=17/16, min_h_div_w_ratio=8 / 16):
if h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio:
return True
return False
def read_parquet(path):
df = pd.read_parquet(path)
data = df.to_dict(orient='records')
return data
class DecordDecoder(object):
def __init__(self, url, num_threads=1):
self.num_threads = num_threads
self.ctx = decord.cpu(0)
self.reader = decord.VideoReader(url,
ctx=self.ctx,
num_threads=self.num_threads)
def get_avg_fps(self):
return self.reader.get_avg_fps() if self.reader.get_avg_fps() > 0 else 30.0
def get_num_frames(self):
return len(self.reader)
def get_height(self):
return self.reader[0].shape[0] if self.get_num_frames() > 0 else 0
def get_width(self):
return self.reader[0].shape[1] if self.get_num_frames() > 0 else 0
# output shape [T, H, W, C]
def get_batch(self, frame_indices):
try:
#frame_indices[0] = 1000
video_data = self.reader.get_batch(frame_indices).asnumpy()
video_data = torch.from_numpy(video_data)
return video_data
except Exception as e:
print('get_batch exception:', e)
return None
class T2V_dataset(Dataset):
def __init__(self, args, transform, temporal_sample, tokenizer_1, tokenizer_2):
self.data = args.data
self.num_frames = args.num_frames
self.train_fps = args.train_fps
self.transform = transform
self.temporal_sample = temporal_sample
self.tokenizer_1 = tokenizer_1
self.tokenizer_2 = tokenizer_2
self.model_max_length = args.model_max_length
self.cfg = args.cfg
self.speed_factor = args.speed_factor
self.max_height = args.max_height
self.max_width = args.max_width
self.drop_short_ratio = args.drop_short_ratio
self.hw_stride = args.hw_stride
self.force_resolution = args.force_resolution
self.max_hxw = args.max_hxw
self.min_hxw = args.min_hxw
self.sp_size = args.sp_size
assert self.speed_factor >= 1
self.video_reader = 'decord' if args.use_decord else 'opencv'
self.ae_stride_t = args.ae_stride_t
self.total_batch_size = args.total_batch_size
self.seed = 42
self.generator = torch.Generator().manual_seed(self.seed)
self.hw_aspect_thr = 2.0 # just a threshold
self.too_long_factor = 5.0
self.support_Chinese = False
if 'mt5' in args.text_encoder_name_1:
self.support_Chinese = True
if args.text_encoder_name_2 is not None and 'mt5' in args.text_encoder_name_2:
self.support_Chinese = True
s = time.time()
cap_list, self.sample_size, self.shape_idx_dict = self.define_frame_index(self.data)
e = time.time()
print(f'Build data time: {e-s}')
self.lengths = self.sample_size
n_elements = len(cap_list)
dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, n_elements)
print(f"Data length: {len(dataset_prog.cap_list)}")
self.executor = ThreadPoolExecutor(max_workers=1)
self.timeout = 60
def set_checkpoint(self, n_used_elements):
for i in range(len(dataset_prog.n_used_elements)):
dataset_prog.n_used_elements[i] = n_used_elements
def __len__(self):
return dataset_prog.n_elements
def __getitem__(self, idx):
try:
future = self.executor.submit(self.get_data, idx)
data = future.result(timeout=self.timeout)
# data = self.get_data(idx)
return data
except Exception as e:
if len(str(e)) < 2:
e = f"TimeoutError, {self.timeout}s timeout occurred with {dataset_prog.cap_list[idx]['path']}"
print(f'Error with {e}')
index_cand = self.shape_idx_dict[self.sample_size[idx]] # pick same shape
return self.__getitem__(random.choice(index_cand))
def get_data(self, idx):
path = dataset_prog.cap_list[idx]['path']
if path.endswith('.mp4'):
return self.get_video(idx)
else:
return self.get_image(idx)
def get_video(self, idx):
video_data = dataset_prog.cap_list[idx]
video_path = video_data['path']
assert os.path.exists(video_path), f"file {video_path} does not exist!"
sample_h = video_data['resolution']['sample_height']
sample_w = video_data['resolution']['sample_width']
if self.video_reader == 'decord':
video = self.decord_read(video_data)
elif self.video_reader == 'opencv':
video = self.opencv_read(video_data)
else:
raise NotImplementedError(f'Found {self.video_reader}, but only decord or opencv is supported')
# import ipdb;ipdb.set_trace()
video = self.transform(video) # T C H W -> T C H W
assert video.shape[2] == sample_h and video.shape[3] == sample_w, f'sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape})'
# video = torch.rand(105, 3, 640, 640)
video = video.transpose(0, 1) # T C H W -> C T H W
text = video_data['cap']
if not isinstance(text, list):
text = [text]
text = [random.choice(text)]
if video_data.get('aesthetic', None) is not None or video_data.get('aes', None) is not None:
aes = video_data.get('aesthetic', None) or video_data.get('aes', None)
text = [add_aesthetic_notice_video(text[0], aes)]
text = text_preprocessing(text, support_Chinese=self.support_Chinese)
text = text if random.random() > self.cfg else ""
text_tokens_and_mask_1 = self.tokenizer_1(
text,
max_length=self.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_1 = text_tokens_and_mask_1['input_ids']
cond_mask_1 = text_tokens_and_mask_1['attention_mask']
input_ids_2, cond_mask_2 = None, None
if self.tokenizer_2 is not None:
text_tokens_and_mask_2 = self.tokenizer_2(
text,
max_length=self.tokenizer_2.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_2 = text_tokens_and_mask_2['input_ids']
cond_mask_2 = text_tokens_and_mask_2['attention_mask']
return dict(
pixel_values=video, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1,
input_ids_2=input_ids_2, cond_mask_2=cond_mask_2,
)
def get_image(self, idx):
image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...]
sample_h = image_data['resolution']['sample_height']
sample_w = image_data['resolution']['sample_width']
image = Image.open(image_data['path']).convert('RGB') # [h, w, c]
image = torch.from_numpy(np.array(image)) # [h, w, c]
image = rearrange(image, 'h w c -> c h w').unsqueeze(0) # [1 c h w]
image = self.transform(image) # [1 C H W] -> num_img [1 C H W]
assert image.shape[2] == sample_h and image.shape[3] == sample_w, f"image_data: {image_data}, but found image {image.shape}"
# image = torch.rand(1, 3, sample_h, sample_w)
image = image.transpose(0, 1) # [1 C H W] -> [C 1 H W]
caps = image_data['cap'] if isinstance(image_data['cap'], list) else [image_data['cap']]
caps = [random.choice(caps)]
if image_data.get('aesthetic', None) is not None or image_data.get('aes', None) is not None:
aes = image_data.get('aesthetic', None) or image_data.get('aes', None)
caps = [add_aesthetic_notice_image(caps[0], aes)]
text = text_preprocessing(caps, support_Chinese=self.support_Chinese)
text = text if random.random() > self.cfg else ""
text_tokens_and_mask_1 = self.tokenizer_1(
text,
max_length=self.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_1 = text_tokens_and_mask_1['input_ids'] # 1, l
cond_mask_1 = text_tokens_and_mask_1['attention_mask'] # 1, l
input_ids_2, cond_mask_2 = None, None
if self.tokenizer_2 is not None:
text_tokens_and_mask_2 = self.tokenizer_2(
text,
max_length=self.tokenizer_2.model_max_length,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
input_ids_2 = text_tokens_and_mask_2['input_ids'] # 1, l
cond_mask_2 = text_tokens_and_mask_2['attention_mask'] # 1, l
return dict(
pixel_values=image, input_ids_1=input_ids_1, cond_mask_1=cond_mask_1,
input_ids_2=input_ids_2, cond_mask_2=cond_mask_2
)
def define_frame_index(self, data):
shape_idx_dict = {}
new_cap_list = []
sample_size = []
aesthetic_score = []
cnt_vid = 0
cnt_img = 0
cnt_too_long = 0
cnt_too_short = 0
cnt_no_cap = 0
cnt_no_resolution = 0
cnt_no_aesthetic = 0
cnt_img_res_mismatch_stride = 0
cnt_vid_res_mismatch_stride = 0
cnt_img_aspect_mismatch = 0
cnt_vid_aspect_mismatch = 0
cnt_img_res_too_small = 0
cnt_vid_res_too_small = 0
cnt_vid_after_filter = 0
cnt_img_after_filter = 0
cnt = 0
with open(data, 'r') as f:
folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
for sub_root, anno in tqdm(folder_anno):
print(f'Building {anno}...')
if anno.endswith('.json'):
with open(anno, 'r') as f:
sub_list = json.load(f)
elif anno.endswith('.pkl'):
with open(anno, "rb") as f:
sub_list = pickle.load(f)
for index, i in enumerate(tqdm(sub_list)):
cnt += 1
path = os.path.join(sub_root, i['path'])
i['path'] = path
if path.endswith('.mp4'):
cnt_vid += 1
elif path.endswith('.jpg'):
cnt_img += 1
# ======no aesthetic=====
if i.get('aesthetic', None) is None and i.get('aes', None) is None:
cnt_no_aesthetic += 1
else:
aesthetic_score.append(i.get('aesthetic', None) or i.get('aes', None))
# ======no caption=====
cap = i.get('cap', None)
if cap is None:
cnt_no_cap += 1
continue
# ======resolution mismatch=====
if i.get('resolution', None) is None:
cnt_no_resolution += 1
continue
else:
if i['resolution'].get('height', None) is None or i['resolution'].get('width', None) is None:
cnt_no_resolution += 1
continue
else:
height, width = i['resolution']['height'], i['resolution']['width']
if not self.force_resolution:
if height <= 0 or width <= 0:
cnt_no_resolution += 1
continue
tr_h, tr_w = maxhwresize(height, width, self.max_hxw)
_, _, sample_h, sample_w = get_params(tr_h, tr_w, self.hw_stride)
if sample_h <= 0 or sample_w <= 0:
if path.endswith('.mp4'):
cnt_vid_res_mismatch_stride += 1
elif path.endswith('.jpg'):
cnt_img_res_mismatch_stride += 1
continue
# filter min_hxw
if sample_h * sample_w < self.min_hxw:
if path.endswith('.mp4'):
cnt_vid_res_too_small += 1
elif path.endswith('.jpg'):
cnt_img_res_too_small += 1
continue
# filter aspect
is_pick = filter_resolution(
sample_h, sample_w, max_h_div_w_ratio=self.hw_aspect_thr, min_h_div_w_ratio=1/self.hw_aspect_thr
)
if not is_pick:
if path.endswith('.mp4'):
cnt_vid_aspect_mismatch += 1
elif path.endswith('.jpg'):
cnt_img_aspect_mismatch += 1
continue
i['resolution'].update(dict(sample_height=sample_h, sample_width=sample_w))
else:
aspect = self.max_height / self.max_width
is_pick = filter_resolution(
height, width, max_h_div_w_ratio=self.hw_aspect_thr*aspect, min_h_div_w_ratio=1/self.hw_aspect_thr*aspect
)
if not is_pick:
if path.endswith('.mp4'):
cnt_vid_aspect_mismatch += 1
elif path.endswith('.jpg'):
cnt_img_aspect_mismatch += 1
continue
sample_h, sample_w = self.max_height, self.max_width
i['resolution'].update(dict(sample_height=sample_h, sample_width=sample_w))
if path.endswith('.mp4'):
fps = i.get('fps', 24)
# max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
if i['num_frames'] > self.too_long_factor * (self.num_frames * fps / self.train_fps * self.speed_factor): # too long video is not suitable for this training stage (self.num_frames)
cnt_too_long += 1
continue
# resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps
start_frame_idx = i.get('cut', [0])[0]
i['start_frame_idx'] = start_frame_idx
frame_indices = np.arange(start_frame_idx, start_frame_idx+i['num_frames'], frame_interval).astype(int)
frame_indices = frame_indices[frame_indices < start_frame_idx+i['num_frames']]
# comment out it to enable dynamic frames training
if len(frame_indices) < self.num_frames and torch.rand(1, generator=self.generator).item() < self.drop_short_ratio:
cnt_too_short += 1
continue
# too long video will be temporal-crop randomly
if len(frame_indices) > self.num_frames:
begin_index, end_index = self.temporal_sample(len(frame_indices))
frame_indices = frame_indices[begin_index: end_index]
# frame_indices = frame_indices[:self.num_frames] # head crop
# to find a suitable end_frame_idx, to ensure we do not need pad video
end_frame_idx = find_closest_y(
len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size
)
if end_frame_idx == -1: # too short to be encoded exactly by the video VAE
cnt_too_short += 1
continue
frame_indices = frame_indices[:end_frame_idx]
i['sample_frame_index'] = frame_indices.tolist()
new_cap_list.append(i)
cnt_vid_after_filter += 1
elif path.endswith('.jpg'): # image
cnt_img_after_filter += 1
i['sample_frame_index'] = [0]
new_cap_list.append(i)
else:
raise NameError(f"Unknown file extension {path.split('.')[-1]}, only support .mp4 for video and .jpg for image")
pre_define_shape = f"{len(i['sample_frame_index'])}x{sample_h}x{sample_w}"
sample_size.append(pre_define_shape)
# if shape_idx_dict.get(pre_define_shape, None) is None:
# shape_idx_dict[pre_define_shape] = [index]
# else:
# shape_idx_dict[pre_define_shape].append(index)
counter = Counter(sample_size)
counter_cp = counter
if not self.force_resolution and self.max_hxw is not None and self.min_hxw is not None:
assert all([np.prod(np.array(k.split('x')[1:]).astype(np.int32)) <= self.max_hxw for k in counter_cp.keys()])
assert all([np.prod(np.array(k.split('x')[1:]).astype(np.int32)) >= self.min_hxw for k in counter_cp.keys()])
len_before_filter_major = len(sample_size)
filter_major_num = 4 * self.total_batch_size
new_cap_list, sample_size = zip(*[[i, j] for i, j in zip(new_cap_list, sample_size) if counter[j] >= filter_major_num])
for idx, shape in enumerate(sample_size):
if shape_idx_dict.get(shape, None) is None:
shape_idx_dict[shape] = [idx]
else:
shape_idx_dict[shape].append(idx)
cnt_filter_minority = len_before_filter_major - len(sample_size)
counter = Counter(sample_size)
print(f'no_cap: {cnt_no_cap}, no_resolution: {cnt_no_resolution}\n'
f'too_long: {cnt_too_long}, too_short: {cnt_too_short}\n'
f'cnt_img_res_mismatch_stride: {cnt_img_res_mismatch_stride}, cnt_vid_res_mismatch_stride: {cnt_vid_res_mismatch_stride}\n'
f'cnt_img_res_too_small: {cnt_img_res_too_small}, cnt_vid_res_too_small: {cnt_vid_res_too_small}\n'
f'cnt_img_aspect_mismatch: {cnt_img_aspect_mismatch}, cnt_vid_aspect_mismatch: {cnt_vid_aspect_mismatch}\n'
f'cnt_filter_minority: {cnt_filter_minority}\n'
f'Counter(sample_size): {counter}\n'
f'cnt_vid: {cnt_vid}, cnt_vid_after_filter: {cnt_vid_after_filter}, use_ratio: {round(cnt_vid_after_filter/(cnt_vid+1e-6), 5)*100}%\n'
f'cnt_img: {cnt_img}, cnt_img_after_filter: {cnt_img_after_filter}, use_ratio: {round(cnt_img_after_filter/(cnt_img+1e-6), 5)*100}%\n'
f'before filter: {cnt}, after filter: {len(new_cap_list)}, use_ratio: {round(len(new_cap_list)/cnt, 5)*100}%')
# import ipdb;ipdb.set_trace()
if len(aesthetic_score) > 0:
stats_aesthetic = calculate_statistics(aesthetic_score)
print(f"before filter: {cnt}, after filter: {len(new_cap_list)}\n"
f"aesthetic_score: {len(aesthetic_score)}, cnt_no_aesthetic: {cnt_no_aesthetic}\n"
f"{len([i for i in aesthetic_score if i>=5.75])} > 5.75, 4.5 > {len([i for i in aesthetic_score if i<=4.5])}\n"
f"Mean: {stats_aesthetic['mean']}, Var: {stats_aesthetic['variance']}, Std: {stats_aesthetic['std_dev']}\n"
f"Min: {stats_aesthetic['min']}, Max: {stats_aesthetic['max']}")
return new_cap_list, sample_size, shape_idx_dict
def decord_read(self, video_data):
path = video_data['path']
predefine_frame_indice = video_data['sample_frame_index']
start_frame_idx = video_data['start_frame_idx']
clip_total_frames = video_data['num_frames']
fps = video_data['fps']
s_x, e_x, s_y, e_y = video_data.get('crop', [None, None, None, None])
predefine_num_frames = len(predefine_frame_indice)
# decord_vr = decord.VideoReader(path, ctx=decord.cpu(0), num_threads=1)
decord_vr = DecordDecoder(path)
frame_indices = self.get_actual_frame(
fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice
)
# video_data = decord_vr.get_batch(frame_indices).asnumpy()
# video_data = torch.from_numpy(video_data)
video_data = decord_vr.get_batch(frame_indices)
if video_data is not None:
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
if s_y is not None:
video_data = video_data[:, :, s_y: e_y, s_x: e_x]
else:
raise ValueError(f'Get video_data {video_data}')
# del decord_vr
# gc.collect()
return video_data
def opencv_read(self, video_data):
path = video_data['path']
predefine_frame_indice = video_data['sample_frame_index']
start_frame_idx = video_data['start_frame_idx']
clip_total_frames = video_data['num_frames']
fps = video_data['fps']
s_x, e_x, s_y, e_y = video_data.get('crop', [None, None, None, None])
predefine_num_frames = len(predefine_frame_indice)
cv2_vr = cv2.VideoCapture(path)
if not cv2_vr.isOpened():
raise ValueError(f'can not open {path}')
frame_indices = self.get_actual_frame(
fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice
)
video_data = []
for frame_idx in frame_indices:
cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) # seek to the requested frame
_, frame = cv2_vr.read()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
cv2_vr.release()
video_data = torch.stack(video_data) # (T C H W)
if s_y is not None:
video_data = video_data[:, :, s_y: e_y, s_x: e_x]
return video_data
def get_actual_frame(self, fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice):
# resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps
frame_indices = np.arange(start_frame_idx, start_frame_idx+clip_total_frames, frame_interval).astype(int)
frame_indices = frame_indices[frame_indices < start_frame_idx+clip_total_frames]
# speed up
max_speed_factor = len(frame_indices) / self.num_frames
if self.speed_factor > 1 and max_speed_factor > 1:
# speed_factor = random.uniform(1.0, min(self.speed_factor, max_speed_factor))
speed_factor = min(self.speed_factor, max_speed_factor)
target_frame_count = int(len(frame_indices) / speed_factor)
speed_frame_idx = np.linspace(0, len(frame_indices) - 1, target_frame_count, dtype=int)
frame_indices = frame_indices[speed_frame_idx]
# too long video will be temporal-crop randomly
if len(frame_indices) > self.num_frames:
begin_index, end_index = self.temporal_sample(len(frame_indices))
frame_indices = frame_indices[begin_index: end_index]
# frame_indices = frame_indices[:self.num_frames] # head crop
# to find a suitable end_frame_idx, to ensure we do not need pad video
end_frame_idx = find_closest_y(
len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size
)
if end_frame_idx == -1: # too short to be encoded exactly by the video VAE
raise IndexError(f'video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})')
frame_indices = frame_indices[:end_frame_idx]
if predefine_num_frames != len(frame_indices):
raise ValueError(f'video ({path}) predefine_num_frames ({predefine_num_frames}) ({predefine_frame_indice}) is not equal with frame_indices ({len(frame_indices)}) ({frame_indices})')
if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1:
raise IndexError(f'video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})')
return frame_indices
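The resampling step at the top of `get_actual_frame` can be illustrated in isolation. The following is a minimal sketch in plain Python, assuming the same rule as above (keep every frame when `fps` is within 0.1 of the training rate, otherwise step by `fps / train_fps`). The helper name `resample_frame_indices` is hypothetical and is not part of the repository; `math.ceil` plus a list comprehension stands in for the `np.arange(...).astype(int)` expression.

```python
import math


def resample_frame_indices(fps, train_fps, start_frame_idx, clip_total_frames):
    """Return source-frame indices that approximate playback at train_fps.

    Hypothetical helper: mirrors the np.arange(...).astype(int) expression
    in get_actual_frame without depending on NumPy.
    """
    # Keep every frame when the rates already match; otherwise step by the ratio.
    frame_interval = 1.0 if abs(fps - train_fps) < 0.1 else fps / train_fps
    end = start_frame_idx + clip_total_frames
    # np.arange(start, end, step) yields ceil((end - start) / step) values.
    count = math.ceil((end - start_frame_idx) / frame_interval)
    # int() truncates toward zero, like astype(int) for non-negative values.
    indices = [int(start_frame_idx + k * frame_interval) for k in range(count)]
    # Guard against an index landing on or past the end of the clip.
    return [i for i in indices if i < end]


# A 60 fps source sampled for 24 fps training advances 2.5 frames per step.
print(resample_frame_indices(60, 24, 0, 10))  # [0, 2, 5, 7]
```

The speed-up and temporal-crop steps that follow in the method operate on the list this step produces, so they are left out of the sketch.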
================================================
FILE: opensora/dataset/transform.py
================================================
import torch
import random
import numbers
from torchvision.transforms import RandomCrop, RandomResizedCrop
import statistics
import numpy as np
from PIL import Image  # required by center_crop_arr
import ftfy
import regex as re
import html
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i: i + h, j: j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def center_crop_th_tw(clip, th, tw, top_crop):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
# import ipdb;ipdb.set_trace()
h, w = clip.size(-2), clip.size(-1)
tr = th / tw
if h / w > tr:
# hxw 720x1280 thxtw 320x640 hw_ratio 9/16 > tr_ratio 8/16 newh=1280*320/640=640 neww=1280
new_h = int(w * tr)
new_w = w
else:
# hxw 720x1280 thxtw 480x640 hw_ratio 9/16 < tr_ratio 12/16 newh=720 neww=720/(12/16)=960
# hxw 1080x1920 thxtw 720x1280 hw_ratio 9/16 = tr_ratio 9/16 newh=1080 neww=1080/(9/16)=1920
new_h = h
new_w = int(h / tr)
i = 0 if top_crop else int(round((h - new_h) / 2.0))
j = int(round((w - new_w) / 2.0))
return crop(clip, i, j, new_h, new_w)
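The crop-box arithmetic in `center_crop_th_tw` can be checked without tensors. This is a small sketch, assuming only the height/width logic shown above; `crop_box` is a hypothetical name introduced for illustration and returns the `(i, j, new_h, new_w)` tuple that would be passed to `crop`.

```python
def crop_box(h, w, th, tw, top_crop=False):
    """Return (i, j, new_h, new_w) for a crop with aspect ratio th:tw.

    Hypothetical helper that repeats the arithmetic of center_crop_th_tw.
    """
    tr = th / tw
    if h / w > tr:
        # Source is taller than the target ratio: keep full width, trim height.
        new_h, new_w = int(w * tr), w
    else:
        # Source is wider than (or equal to) the target ratio: keep full height.
        new_h, new_w = h, int(h / tr)
    i = 0 if top_crop else int(round((h - new_h) / 2.0))
    j = int(round((w - new_w) / 2.0))
    return i, j, new_h, new_w


# The two cases from the comments above, for a 720x1280 source.
print(crop_box(720, 1280, 320, 640))  # (40, 0, 640, 1280)
print(crop_box(720, 1280, 480, 640))  # (0, 160, 720, 960)
```

With `top_crop=True` only the vertical offset changes: the box starts at row 0 instead of being centered.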
def random_shift_crop(clip):
'''
Slide along the long edge, with the short edge as crop size
'''
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
long_edge = w
short_edge = h
else:
long_edge = h
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float and divide values by 255.0.
The dimension order of the clip tensor is left unchanged.
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def to_tensor_after_resize(clip):
"""
Convert resized tensor to [0, 1]
Args:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W), but in [0, 1]
"""
_is_tensor_video_clip(clip)
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])  # broadcast over the channel dim of (T, C, H, W)
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def get_params(h, w, stride):
th, tw = h // stride * stride, w // stride * stride
i = (h - th) // 2
j = (w - tw) // 2
return i, j, th, tw
class SpatialStrideCropVideo:
def __init__(self, stride):
self.stride = stride
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: cropped video clip by stride.
size is (T, C, OH, OW)
"""
h, w = clip.shape[-2:]
i, j, h, w = get_params(h, w, self.stride)
return crop(clip, i, j, h, w)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(stride={self.stride})"
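A worked example may help when reading `get_params` and `SpatialStrideCropVideo`: the crop keeps the largest height and width that are multiples of `stride` and centers the box. The sketch below repeats that arithmetic under a hypothetical name, `stride_crop_params`, so it can be run on plain integers.

```python
def stride_crop_params(h, w, stride):
    """Return (i, j, th, tw) for a centered crop whose sides are multiples of stride.

    Hypothetical helper with the same arithmetic as get_params above.
    """
    # Round each side down to the nearest multiple of stride.
    th, tw = h // stride * stride, w // stride * stride
    # Center the box; odd leftovers put the extra pixel on the bottom/right.
    i = (h - th) // 2
    j = (w - tw) // 2
    return i, j, th, tw


# 723x1285 is trimmed to 720x1280 for stride 16.
print(stride_crop_params(723, 1285, 16))  # (1, 2, 720, 1280)
```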
def longsideresize(h, w, size, skip_low_resolution):
if h <= size[0] and w <= size[1] and skip_low_resolution:
return h, w
if h / w > size[0] / size[1]:
# hxw 720x1280 size 320x640 hw_ratio 9/16 > size_ratio 8/16 neww=320/720*1280=568 newh=320
w = int(size[0] / h * w)
h = size[0]
else:
# hxw 720x1280 size 480x640 hw_ratio 9/16 < size_ratio 12/16 newh=640/1280*720=360 neww=640
# hxw 1080x1920 size 720x1280 hw_ratio 9/16 = size_ratio 9/16 newh=1280/1920*1080=720 neww=1280
h = int(size[1] / w * h)
w = size[1]
return h, w
def maxhwresize(ori_height, ori_width, max_hxw):
if ori_height * ori_width > max_hxw:
scale_factor = np.sqrt(max_hxw / (ori_height * ori_width))
new_height = int(ori_height * scale_factor)
new_width = int(ori_width * scale_factor)
else:
new_height = ori_height
new_width = ori_width
return new_height, new_width
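The two size helpers above differ in what they bound: `longsideresize` fits the clip inside a `(height, width)` box, while `maxhwresize` caps the pixel area. The following sketch reruns both calculations on plain integers. The names `fit_inside` and `cap_area` are hypothetical, and `math.sqrt` is swapped in for `np.sqrt` so the example needs no third-party packages.

```python
import math


def fit_inside(h, w, size, skip_low_resolution=False):
    """Scale (h, w) to fit inside size=(height, width), keeping aspect ratio."""
    if h <= size[0] and w <= size[1] and skip_low_resolution:
        return h, w
    if h / w > size[0] / size[1]:
        # Height is the limiting side: match it and scale the width.
        return size[0], int(size[0] / h * w)
    # Width is the limiting side: match it and scale the height.
    return int(size[1] / w * h), size[1]


def cap_area(h, w, max_hxw):
    """Downscale (h, w) so that the area is at most max_hxw; never upscale."""
    if h * w <= max_hxw:
        return h, w
    scale = math.sqrt(max_hxw / (h * w))
    return int(h * scale), int(w * scale)


print(fit_inside(720, 1280, (320, 640)))  # (320, 568)
print(fit_inside(720, 1280, (480, 640)))  # (360, 640)
print(cap_area(1080, 1920, 500000))       # (530, 942)
```

Because both helpers truncate with `int`, the result can be slightly smaller than the bound, for example 530 * 942 = 499260 pixels for a 500000-pixel cap.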
class LongSideResizeVideo:
'''
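Resize the clip so that it fits inside `size` (height, width) while keeping the aspect ratio.
With skip_low_resolution, clips that already fit are returned unchanged.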
'''
def __init__(
self,
size,
skip_low_resolution=False,
interpolation_mode="bilinear",
):
self.size = size
self.skip_low_resolution = skip_low_resolution
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized video clip.
"""
_, _, h, w = clip.shape
tr_h, tr_w = longsideresize(h, w, self.size, self.skip_low_resolution)
if h == tr_h and w == tr_w:
return clip
resize_clip = resize(clip, target_size=(tr_h, tr_w),
interpolation_mode=self.interpolation_mode)
return resize_clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class MaxHWResizeVideo:
'''
If the clip area h*w exceeds max_hxw, downscale the clip (keeping the aspect ratio)
so that the area is at most max_hxw; otherwise return it unchanged.
'''
def __init__(
self,
max_hxw,
interpolation_mode="bilinear",
):
self.max_hxw = max_hxw
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized video clip.
"""
_, _, h, w = clip.shape
tr_h, tr_w = maxhwresize(h, w, self.max_hxw)
if h == tr_h and w == tr_w:
return clip
resize_clip = resize(clip, target_size=(tr_h, tr_w),
interpolation_mode=self.interpolation_mode)
return resize_clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(max_hxw={self.max_hxw}, interpolation_mode={self.interpolation_mode})"
class CenterCropResizeVideo:
'''
Center crop the clip to the aspect ratio of `size` (crop from the top when top_crop is set),
then resize to the specified size
'''
def __init__(
self,
size,
top_crop=False,
interpolation_mode="bilinear",
):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
self.top_crop = top_crop
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, size[0], size[1])
"""
clip_center_crop = center_crop_th_tw(clip, self.size[0], self.size[1], top_crop=self.top_crop)
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,
interpolation_mode=self.interpolation_mode)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class UCFCenterCropVideo:
'''
First scale to the specified size in equal proportion to the short edge,
then center cropping
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class KineticsRandomCropResizeVideo:
'''
Slide along the long edge, with the short edge as crop size, then resize to the desired size.
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip to be normalized. Size is (T, C, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float and divide values by 255.0.
The dimension order of the clip tensor is left unchanged.
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class ToTensorAfterResize:
"""
Convert an already resized clip to float and divide values by 255.0 so they lie in [0, 1].
The dimension order of the clip tensor is left unchanged.
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W), but in [0, 1]
"""
return to_tensor_after_resize(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size)  # random.randint is inclusive, so the last full window can be drawn
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
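class DynamicSampleDuration(object):
"""Randomly choose a truncated duration that keeps at least half of the frames.
Args:
t_stride (int): Step between candidate durations.
extra_1 (bool): If True, candidates are computed on t - 1 and 1 is added back to the chosen duration.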
"""
def __init__(self, t_stride, extra_1):
self.t_stride = t_stride
self.extra_1 = extra_1
def __call__(self, t, h, w):
if self.extra_1:
t = t - 1
truncate_t_list = list(range(t+1))[t//2:][::self.t_stride] # need half at least
truncate_t = random.choice(truncate_t_list)
if self.extra_1:
truncate_t = truncate_t + 1
return 0, truncate_t
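To see which durations `DynamicSampleDuration` can return, it helps to list the candidates before the random choice. The sketch below assumes only the slicing logic shown above; `candidate_durations` is a hypothetical helper that returns the full candidate list instead of sampling from it.

```python
def candidate_durations(t, t_stride, extra_1):
    """List every duration DynamicSampleDuration.__call__ could pick for length t.

    Hypothetical helper: same slicing as above, without the random.choice.
    """
    if extra_1:
        t = t - 1
    # Start at the halfway point, then step by t_stride up to and including t.
    candidates = list(range(t + 1))[t // 2:][::t_stride]
    if extra_1:
        candidates = [c + 1 for c in candidates]
    return candidates


print(candidate_durations(33, 8, True))   # [17, 25, 33]
print(candidate_durations(32, 8, False))  # [16, 24, 32]
```

With `extra_1=True` every candidate has the form `k * t_stride + 1` whenever `(t - 1) // 2` is itself a multiple of `t_stride`, as in the first call.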
keywords = [
' man ', ' woman ', ' person ', ' people ', 'human',
' individual ', ' child ', ' kid ', ' girl ', ' boy ',
]
keywords += [i[:-1] + 's ' for i in keywords]
masking_notices = [
"Note: The faces in this image are blurred.",
"This image contains faces that have been pixelated.",
"Notice: Faces in this image are masked.",
"Please be aware that the faces in this image are obscured.",
"The faces in this image are hidden.",
"This is an image with blurred faces.",
"The faces in this image have been processed.",
"Attention: Faces in this image are not visible.",
"The faces in this image are partially blurred.",
"This image has masked faces.",
"Notice: The faces in this picture have been altered.",
"This is a picture with obscured faces.",
"The faces in this image are pixelated.",
"Please note, the faces in this image have been blurred.",
"The faces in this photo are hidden.",
"The faces in this picture have been masked.",
"Note: The faces in this picture are altered.",
"This is an image where faces are not clear.",
"Faces in this image have been obscured.",
"This picture contains masked faces.",
"The faces in this image are processed.",
"The faces in this picture are not visible.",
"Please be aware, the faces in this photo are pixelated.",
"The faces in this picture have been blurred.",
]
webvid_watermark_notices = [
"This video has a faint Shutterstock watermark in the center.",
"There is a slight Shutterstock watermark in the middle of this video.",
"The video contains a subtle Shutterstock watermark in the center.",
"This video features a light Shutterstock watermark at its center.",
"A faint Shutterstock watermark is present in the middle of this video.",
"There is a mild Shutterstock watermark at the center of this video.",
"This video has a slight Shutterstock watermark in the middle.",
"You can see a faint Shutterstock watermark in the center of this video.",
"A subtle Shutterstock watermark appears in the middle of this video.",
"This video includes a light Shutterstock watermark at its center.",
]
high_aesthetic_score_notices_video = [
"This video has a high aesthetic quality.",
"The beauty of this video is exceptional.",
"This video scores high in aesthetic value.",
"With its harmonious colors and balanced composition.",
"This video ranks highly for aesthetic quality.",
"The artistic quality of this video is excellent.",
"This video is rated high for beauty.",
"The aesthetic quality of this video is impressive.",
"This video has a top aesthetic score.",
"The visual appeal of this video is outstanding.",
]
low_aesthetic_score_notices_video = [
"This video has a low aesthetic quality.",
"The beauty of this video is minimal.",
"This video scores low in aesthetic appeal.",
"The aesthetic quality of this video is below average.",
"This video ranks low for beauty.",
"The artistic quality of this video is lacking.",
"This video has a low score for aesthetic value.",
"The visual appeal of this video is low.",
"This video is rated low for beauty.",
"The aesthetic quality of this video is poor.",
]
high_aesthetic_score_notices_image = [
"This image has a high aesthetic quality.",
"The beauty of this image is exceptional.",
"This photo scores high in aesthetic value.",
"With its harmonious colors and balanced composition.",
"This image ranks highly for aesthetic quality.",
"The artistic quality of this photo is excellent.",
"This image is rated high for beauty.",
"The aesthetic quality of this image is impressive.",
"This photo has a top aesthetic score.",
"The visual appeal of this image is outstanding.",
]
low_aesthetic_score_notices_image = [
"This image has a low aesthetic quality.",
"The beauty of this image is minimal.",
"This image scores low in aesthetic appeal.",
"The aesthetic quality of this image is below average.",
"This image ranks low for beauty.",
"The artistic quality of this image is lacking.",
"This image has a low score for aesthetic value.",
"The visual appeal of this image is low.",
"This image is rated low for beauty.",
"The aesthetic quality of this image is poor.",
]
high_aesthetic_score_notices_image_human = [
"High-quality image with visible human features and high aesthetic score.",
"Clear depiction of an individual in a high-quality image with top aesthetics.",
"High-resolution photo showcasing visible human details and high beauty rating.",
"Detailed, high-quality image with well-defined human subject and strong aesthetic appeal.",
"Sharp, high-quality portrait with clear human features and high aesthetic value.",
"High-quality image featuring a well-defined human presence and exceptional aesthetics.",
"Visible human details in a high-resolution photo with a high aesthetic score.",
"Clear, high-quality image with prominent human subject and superior aesthetic rating.",
"High-quality photo capturing a visible human with excellent aesthetics.",
"Detailed, high-quality image of a human with high visual appeal and aesthetic value.",
]
def add_masking_notice(caption):
if any(keyword in caption for keyword in keywords):
notice = random.choice(masking_notices)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
return caption
def add_webvid_watermark_notice(caption):
notice = random.choice(webvid_watermark_notices)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
def add_aesthetic_notice_video(caption, aesthetic_score):
if aesthetic_score <= 4.25:
notice = random.choice(low_aesthetic_score_notices_video)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
if aesthetic_score >= 5.75:
notice = random.choice(high_aesthetic_score_notices_video)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
return caption
def add_aesthetic_notice_image(caption, aesthetic_score):
if aesthetic_score <= 4.25:
notice = random.choice(low_aesthetic_score_notices_image)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
if aesthetic_score >= 5.75:
notice = random.choice(high_aesthetic_score_notices_image)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
return caption
def add_high_aesthetic_notice_image(caption):
notice = random.choice(high_aesthetic_score_notices_image)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
def add_high_aesthetic_notice_image_human(caption):
notice = random.choice(high_aesthetic_score_notices_image_human)
return random.choice([caption + ' ' + notice, notice + ' ' + caption])
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def clean_youtube(text, is_tags=False):
text = text.lower() + ' '
text = re.sub(
r'#video|video|#shorts|shorts| shorts|#short| short|#youtubeshorts|youtubeshorts|#youtube| youtube|#shortsyoutube|#ytshorts|ytshorts|#ytshort|#shortvideo|shortvideo|#shortsfeed|#tiktok|tiktok|#tiktokchallenge|#myfirstshorts|#myfirstshort|#viral|viralvideo|viral|#viralshorts|#virlshort|#ytviralshorts|#instagram',
' ', text)
text = re.sub(r' s |short|youtube|virlshort|#', ' ', text)
pattern = r'[^a-zA-Z0-9\s\.,;:?!\'\"|]'
if is_tags:
pattern = r'[^a-zA-Z0-9\s]'
text = re.sub(pattern, '', text)
text = whitespace_clean(basic_clean(text))
return text
def clean_vidal(text):
title_hashtags = text.split('#')
title, hashtags = title_hashtags[0], '#' + '#'.join(title_hashtags[1:])
title = clean_youtube(title)
hashtags = clean_youtube(hashtags, is_tags=True)
# check before joining: the joined string always contains ', ' and is never empty
if title == '' and hashtags == '':
raise ValueError('text is empty')
text = title + ', ' + hashtags
return text
def calculate_statistics(data):
if len(data) == 0:
return None
data = np.array(data)
mean = np.mean(data)
variance = np.var(data)
std_dev = np.std(data)
minimum = np.min(data)
maximum = np.max(data)
return {
'mean': mean,
'variance': variance,
'std_dev': std_dev,
'min': minimum,
'max': maximum
}
if __name__ == '__main__':
from torchvision import transforms
import torchvision.io as io
import numpy as np
from torchvision.utils import save_image
import os
vframes, aframes, info = io.read_video(
filename='./v_Archery_g01_c03.avi',
pts_unit='sec',
output_format='TCHW'
)
trans = transforms.Compose([
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
os.makedirs('./test000', exist_ok=True)  # save_image does not create the output directory
for i in range(target_video_len):
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,
value_range=(-1, 1))
================================================
FILE: opensora/dataset/virtual_disk.py
================================================
import subprocess
import json
import pickle
from collections import OrderedDict
from opensora.npu_config import npu_config
import sys
import os
class SuppressStdout:
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(SuppressStdout, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __enter__(self):
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
def __exit__(self, exc_type, exc_value, traceback):
sys.stdout.close()
sys.stdout = self._original_stdout
# Create singleton
class ObsConnection:
"""
AK, SK, STS_TOKEN: temporary credentials are valid for at most 24 h on the cloud computing website
buckets & object: https://uconsole.ccaicc.com/#/mgt/modelarts -> Object Console
keys & tokens: https://uconsole.ccaicc.com/#/mgt/modelarts -> Object Console -> Get access keys (AK and SK)
"""
def __init__(self):
with open(f"{npu_config.work_path}/scripts/train_data/key.json", "r") as f:
key = json.load(f)
self.AK = key["AK"]
self.SK = key["SK"]
self.endpoint = key["EP"]
self.bucket = "sora"
self.suppress_stdout = SuppressStdout()
def connect(self, obs):
config_command = [
obs, 'config',
'-i=' + self.AK,
'-k=' + self.SK,
'-e=' + self.endpoint
]
result = subprocess.run(config_command, capture_output=True, text=True)
if result.returncode != 0:
print(f"Failed to configure obsutil: {result.stderr}")
else:
print("Successfully configured obsutil")
class VirtualDisk:
"""
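:param storage_dir: mount point path of the in-memory virtual disk.
:param size: size of the in-memory virtual disk, e.g. '1G'.
:param obs: location of the obsutil executable on the Linux system.
:param connection: abstraction that manages the OBS connection.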
"""
def __init__(self, storage_dir, size="1G", obs="/home/opensora/obsutil_linux_arm64_5.5.12/obsutil"):
self.obs = obs
self.connection = ObsConnection()
self.connection.connect(obs)
os.makedirs(storage_dir, exist_ok=True)
self.storage_dir = storage_dir
self.size = self._convert_size_to_bytes(size)
if not self.is_tmpfs_mounted():
self.create_ramdisk()
else:
print(f"{self.storage_dir} is already mounted as tmpfs.")
self.index_file = os.path.join(self.storage_dir, 'index.pkl')
self.index = self.load_index()
self.lru = OrderedDict()
self.current_size = self.get_total_storage_size() # compute the total size at initialization
def _convert_size_to_bytes(self, size):
unit = size[-1].upper()
size_value = int(size[:-1])
if unit == 'K':
return size_value * 1024
elif unit == 'M':
return size_value * 1024 ** 2
elif unit == 'G':
return size_value * 1024 ** 3
else:
raise ValueError("Invalid size unit. Use K, M, or G.")
def create_ramdisk(self):
"""
Create and mount a tmpfs in-memory virtual disk.
"""
try:
# Create the mount point directory if it does not exist
if not os.path.exists(self.storage_dir):
os.makedirs(self.storage_dir)
# Mount tmpfs on the mount point
subprocess.run(['sudo', 'mount', '-t', 'tmpfs', '-o', f'size={self.size}', 'tmpfs', self.storage_dir], check=True)
print(f"Successfully mounted tmpfs on {self.storage_dir} with size {self.size}.")
except subprocess.CalledProcessError as e:
print(f"Failed to mount tmpfs: {e}")
except Exception as e:
print(f"An error occurred: {e}")
def load_index(self):
"""
Load the index file.
:return: the index dictionary.
"""
if os.path.exists(self.index_file):
with open(self.index_file, 'rb') as f:
return pickle.load(f)
return {}
def save_index(self):
"""
Save the index file.
"""
with open(self.index_file, 'wb') as f:
pickle.dump(self.index, f)
def unmount_ramdisk(self):
"""
Unmount the in-memory virtual disk mounted at self.storage_dir.
"""
try:
# Unmount; make sure no process is still using the mount point
subprocess.run(['sudo', 'umount', self.storage_dir], check=True)
print(f"Successfully unmounted tmpfs from {self.storage_dir}.")
except subprocess.CalledProcessError as e:
print(f"Failed to unmount tmpfs: {e}")
except Exception as e:
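print(f"An error occurred: {e}")
def is_tmpfs_mounted(self):
"""
Check whether the mount point is already mounted.
Note: `mountpoint -q` only reports that the path is a mount point; it does not verify the filesystem type.
:return: True if self.storage_dir is a mount point, otherwise False.
"""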
try:
result = subprocess.run(['mountpoint', '-q', self.storage_dir], check=False)
if result.returncode == 0:
return True
return False
except Exception as e:
print(f"An error occurred while checking if tmpfs is mounted: {e}")
return False
def get_data(self, key):
"""
Fetch the data for a key onto local disk, downloading it from the remote with obsutil.
:param key: unique key of the data.
:return: local path of the downloaded file.
"""
# if key in self.index:
# data_file = self.index[key]
# if os.path.exists(data_file):
# self.lru.move_to_end(key)
# with open(data_file, 'rb') as f:
# # print(f"Successfully get {key} from local")
# return pickle.load(f)
# If the data does not exist locally, fetch it from the remote with obsutil
object_name = key # assume key corresponds to the remote object name
local_path = os.path.join(self.storage_dir, key)
with self.connection.suppress_stdout:
self.download_and_convert_to_pickle(self.connection.bucket, object_name, local_path)
# Record where the data is stored
# self.index[key] = local_path
# self.save_index()
# self.lru[key] = local_path
#
# file_size = os.path.getsize(local_path)
# self.current_size += file_size
# self.ensure_storage_limit()
return local_path
def del_data(self, local_path):
os.remove(local_path)
def download_and_convert_to_pickle(self, bucket, object_name, local_path):
"""
Download a file from OBS with obsutil to a local path (despite the name, no pickle conversion is performed here).
:param bucket: OBS bucket name.
:param object_name: object name in OBS.
:param local_path: local file path.
"""
# try:
# Download the file to local_path
subprocess.run([self.obs, 'cp', f'obs://{bucket}/{object_name}', local_path], check=True)
# print(f"Successfully downloaded obs://{bucket}/{object_name} to {local_path}.")
# except subprocess.CalledProcessError as e:
# print(f"Failed to download obs://{bucket}/{object_name} to {local_path}: {e}")
def ensure_storage_limit(self):
"""
Ensure the total stored size does not exceed the virtual disk size; when it does, delete the oldest files according to the LRU policy.
"""
while self.current_size > self.size:
oldest_key, oldest_path = self.lru.popitem(last=False)
file_size = os.path.getsize(oldest_path)
os.remove(oldest_path)
del self.index[oldest_key]
self.save_index()
print(f"Removed {oldest_key} to free up {file_size} bytes.")
self.current_size -= file_size
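The eviction loop in `ensure_storage_limit` relies on `OrderedDict` ordering: the least recently used key sits at the front and is removed with `popitem(last=False)`. The following is a minimal in-memory sketch of that policy, assuming sizes are tracked per key instead of being read from disk. The class `SizeBoundedLRU` and its methods are hypothetical and exist only for illustration; they are not part of `VirtualDisk`.

```python
from collections import OrderedDict


class SizeBoundedLRU:
    """Track keys and their sizes, evicting the least recently used when full."""

    def __init__(self, capacity_bytes):
        self.capacity_bytes = capacity_bytes
        self.current_size = 0
        self.entries = OrderedDict()  # key -> size in bytes, oldest first

    def touch(self, key):
        # Mark a key as recently used by moving it to the end.
        self.entries.move_to_end(key)

    def add(self, key, size):
        # Replace an existing entry so its size is not counted twice.
        if key in self.entries:
            self.current_size -= self.entries.pop(key)
        self.entries[key] = size
        self.current_size += size
        # Evict from the front (oldest) until the total fits again.
        evicted = []
        while self.current_size > self.capacity_bytes and self.entries:
            old_key, old_size = self.entries.popitem(last=False)
            self.current_size -= old_size
            evicted.append(old_key)
        return evicted


cache = SizeBoundedLRU(capacity_bytes=10)
cache.add("a", 4)
cache.add("b", 4)
cache.touch("a")           # "b" is now the least recently used entry
print(cache.add("c", 4))   # ['b']
```

In `VirtualDisk` the same loop would also remove the file and update the pickled index; here only the bookkeeping is shown.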
def get_total_storage_size(self):
"""
Get the total size of all currently stored files.
:return: total size in bytes.
"""
total_size = 0
for path in self.lru.values():
if os.path.exists(path):
total_size += os.path.getsize(path)
return total_size
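# The LRU policy described in ensure_storage_limit above can be sketched in isolation.
# This is a minimal illustration, not the repository's class: LruByteBudget and its
# method names are hypothetical, sizes are passed in directly instead of being read
# with os.path.getsize, and nothing is deleted from disk.

```python
from collections import OrderedDict


class LruByteBudget:
    """Hypothetical helper: tracks entry sizes and evicts the least recently used."""

    def __init__(self, capacity_bytes):
        self.capacity_bytes = capacity_bytes
        self.current_size = 0
        self.entries = OrderedDict()  # key -> size in bytes, oldest first

    def touch(self, key):
        # Mark an existing entry as most recently used.
        self.entries.move_to_end(key)

    def add(self, key, size):
        # Register an entry, then evict oldest entries until the budget holds.
        # The entry just added is never evicted. Returns the evicted keys.
        if key in self.entries:
            self.current_size -= self.entries.pop(key)
        self.entries[key] = size
        self.current_size += size
        evicted = []
        while self.current_size > self.capacity_bytes and len(self.entries) > 1:
            oldest_key, oldest_size = self.entries.popitem(last=False)
            self.current_size -= oldest_size
            evicted.append(oldest_key)
        return evicted


cache = LruByteBudget(capacity_bytes=10)
cache.add("a", 4)
cache.add("b", 4)
first_eviction = cache.add("c", 4)   # 12 bytes exceed the budget, the oldest entry goes
cache.touch("b")                     # "b" becomes the most recently used entry
second_eviction = cache.add("d", 4)  # "c" is now the oldest entry
```

# Using OrderedDict keeps both move_to_end and popitem(last=False) cheap, which is
# presumably why the original code pairs self.lru with popitem(last=False).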
================================================
FILE: opensora/models/__init__.py
================================================
from .causalvideovae import CausalVAEModelWrapper, WFVAEModelWrapper
================================================
FILE: opensora/models/causalvideovae/__init__.py
================================================
from torchvision.transforms import Lambda
from .model.vae import CausalVAEModel, WFVAEModel
from einops import rearrange
import torch
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:  # torch_npu is optional; fall back when it is unavailable
torch_npu = None
npu_config = None
import torch.nn as nn
class CausalVAEModelWrapper(nn.Module):
def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs):
super(CausalVAEModelWrapper, self).__init__()
self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)
def encode(self, x):
x = self.vae.encode(x).sample().mul_(0.18215)
return x
def decode(self, x):
x = self.vae.decode(x / 0.18215)
x = rearrange(x, 'b c t h w -> b t c h w').contiguous()
return x
def dtype(self):
return self.vae.dtype
class WFVAEModelWrapper(nn.Module):
def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
super(WFVAEModelWrapper, self).__init__()
self.vae = WFVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)
self.register_buffer('shift', torch.tensor(self.vae.config.shift)[None, :, None, None, None])
self.register_buffer('scale', torch.tensor(self.vae.config.scale)[None, :, None, None, None])
def encode(self, x):
x = (self.vae.encode(x).sample() - self.shift.to(x.device, dtype=x.dtype)) * self.scale.to(x.device, dtype=x.dtype)
return x
def decode(self, x):
x = x / self.scale.to(x.device, dtype=x.dtype) + self.shift.to(x.device, dtype=x.dtype)
x = self.vae.decode(x)
x = rearrange(x, 'b c t h w -> b t c h w').contiguous()
return x
def dtype(self):
return self.vae.dtype
ae_wrapper = {
'CausalVAEModel_D4_2x8x8': CausalVAEModelWrapper,
'CausalVAEModel_D8_2x8x8': CausalVAEModelWrapper,
'CausalVAEModel_D4_4x8x8': CausalVAEModelWrapper,
'CausalVAEModel_D8_4x8x8': CausalVAEModelWrapper,
'WFVAEModel_D8_4x8x8': WFVAEModelWrapper,
'WFVAEModel_D16_4x8x8': WFVAEModelWrapper,
'WFVAEModel_D32_4x8x8': WFVAEModelWrapper,
'WFVAEModel_D32_8x8x8': WFVAEModelWrapper,
}
ae_stride_config = {
'CausalVAEModel_D4_2x8x8': [2, 8, 8],
'CausalVAEModel_D8_2x8x8': [2, 8, 8],
'CausalVAEModel_D4_4x8x8': [4, 8, 8],
'CausalVAEModel_D8_4x8x8': [4, 8, 8],
'WFVAEModel_D8_4x8x8': [4, 8, 8],
'WFVAEModel_D16_4x8x8': [4, 8, 8],
'WFVAEModel_D32_4x8x8': [4, 8, 8],
'WFVAEModel_D32_8x8x8': [8, 8, 8],
}
ae_channel_config = {
'CausalVAEModel_D4_2x8x8': 4,
'CausalVAEModel_D8_2x8x8': 8,
'CausalVAEModel_D4_4x8x8': 4,
'CausalVAEModel_D8_4x8x8': 8,
'WFVAEModel_D8_4x8x8': 8,
'WFVAEModel_D16_4x8x8': 16,
'WFVAEModel_D32_4x8x8': 32,
'WFVAEModel_D32_8x8x8': 32,
}
ae_denorm = {
'CausalVAEModel_D4_2x8x8': lambda x: (x + 1.) / 2.,
'CausalVAEModel_D8_2x8x8': lambda x: (x + 1.) / 2.,
'CausalVAEModel_D4_4x8x8': lambda x: (x + 1.) / 2.,
'CausalVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2.,
'WFVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2.,
'WFVAEModel_D16_4x8x8': lambda x: (x + 1.) / 2.,
'WFVAEModel_D32_4x8x8': lambda x: (x + 1.) / 2.,
'WFVAEModel_D32_8x8x8': lambda x: (x + 1.) / 2.,
}
ae_norm = {
'CausalVAEModel_D4_2x8x8': Lambda(lambda x: 2. * x - 1.),
'CausalVAEModel_D8_2x8x8': Lambda(lambda x: 2. * x - 1.),
'CausalVAEModel_D4_4x8x8': Lambda(lambda x: 2. * x - 1.),
'CausalVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.),
'WFVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.),
'WFVAEModel_D16_4x8x8': Lambda(lambda x: 2. * x - 1.),
'WFVAEModel_D32_4x8x8': Lambda(lambda x: 2. * x - 1.),
'WFVAEModel_D32_8x8x8': Lambda(lambda x: 2. * x - 1.),
}
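# The ae_norm / ae_denorm lambdas and the shift/scale arithmetic in WFVAEModelWrapper
# are inverse pairs. The sketch below checks this on plain Python floats instead of
# tensors. The function names and the shift/scale values are made up for
# illustration and do not come from any real checkpoint.

```python
def norm(x):
    # Map a pixel value from [0, 1] to [-1, 1], as the ae_norm lambdas do.
    return 2.0 * x - 1.0


def denorm(x):
    # Map a value from [-1, 1] back to [0, 1], as the ae_denorm lambdas do.
    return (x + 1.0) / 2.0


def scale_latent(z, shift, scale):
    # Per-channel latent normalization, mirroring WFVAEModelWrapper.encode.
    return [(v - s) * k for v, s, k in zip(z, shift, scale)]


def unscale_latent(z, shift, scale):
    # Inverse step, mirroring the first line of WFVAEModelWrapper.decode.
    return [v / k + s for v, s, k in zip(z, shift, scale)]


pixels = [0.0, 0.25, 1.0]
round_trip = [denorm(norm(p)) for p in pixels]

shift = [0.5, -1.0]  # made-up per-channel values
scale = [2.0, 4.0]   # made-up per-channel values
latent = [1.5, 3.0]
restored = unscale_latent(scale_latent(latent, shift, scale), shift, scale)
```

# In the wrapper itself the same arithmetic is broadcast over a (B, C, T, H, W)
# tensor via the [None, :, None, None, None] indexing of the registered buffers.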
================================================
FILE: opensora/models/causalvideovae/dataset/__init__.py
================================================
================================================
FILE: opensora/models/causalvideovae/dataset/ddp_sampler.py
================================================
import math
from typing import TypeVar, Optional, Iterator
import torch
from torch.utils.data import Sampler, Dataset
import torch.distributed as dist
T_co = TypeVar('T_co', covariant=True)
class CustomDistributedSampler(Sampler[T_co]):
r"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
:class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size and that any instance of it always
returns the same elements in the same order.
Args:
dataset: Dataset used for sampling.
num_replicas (int, optional): Number of processes participating in
distributed training. By default, :attr:`world_size` is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
By default, :attr:`rank` is retrieved from the current distributed
group.
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
indices.
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
.. warning::
In distributed mode, calling the :meth:`set_epoch` method at
the beginning of each epoch **before** creating the :class:`DataLoader` iterator
is necessary to make shuffling work properly across multiple epochs. Otherwise,
the same ordering will be always used.
Example::
>>> # xdoctest: +SKIP
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
... sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)
"""
def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
rank: Optional[int] = None, shuffle: bool = True,
seed: int = 0, drop_last: bool = False) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.current_index = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
while self.current_index < len(indices):
yield indices[self.current_index]
self.current_index += 1
self.current_index = 0
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
def state_dict(self) -> dict:
return {
'epoch': self.epoch,
'seed': self.seed,
'current_index': self.current_index
}
def load_state_dict(self, state_dict: dict) -> None:
self.epoch = state_dict['epoch']
self.seed = state_dict['seed']
self.current_index = state_dict.get('current_index', 0)
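# The padding and striding done in __iter__ above can be traced with a small
# standalone function. This is a sketch with shuffling left out for readability;
# shard_indices is a hypothetical name, not part of the sampler's API.

```python
import math


def shard_indices(dataset_len, num_replicas, rank, drop_last=False):
    # Same size arithmetic as CustomDistributedSampler.__init__.
    if drop_last and dataset_len % num_replicas != 0:
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(dataset_len))  # no shuffle in this sketch
    if not drop_last:
        # Repeat indices from the start so the list divides evenly.
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    else:
        # Drop the tail so the list divides evenly.
        indices = indices[:total_size]

    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


padded = [shard_indices(10, 4, r) for r in range(4)]
dropped = [shard_indices(10, 4, r, drop_last=True) for r in range(4)]
```

# With 10 samples and 4 replicas, padding gives every rank 3 indices (two samples
# are seen twice), while drop_last=True gives every rank 2 indices.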
================================================
FILE: opensora/models/causalvideovae/dataset/transform.py
================================================
import torch
import random
import numbers
import numpy as np  # used by center_crop_arr
from PIL import Image  # used by center_crop_arr
from torchvision.transforms import RandomCrop, RandomResizedCrop
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i: i + h, j: j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def random_shift_crop(clip):
'''
Slide along the long edge, with the short edge as crop size
'''
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
long_edge = w
short_edge = h
else:
long_edge = h
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float and divide value by 255.0.
The dimension order of the clip tensor is left unchanged.
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W), since mean and std are broadcast along the first dimension
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (C, T, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class SpatialStrideCropVideo:
def __init__(self, stride):
self.stride = stride
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: cropped video clip by stride.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = h // self.stride * self.stride, w // self.stride * self.stride
return 0, 0, th, tw # from top-left
def __repr__(self) -> str:
return f"{self.__class__.__name__}(stride={self.stride})"
class LongSideResizeVideo:
'''
Resize the clip so that its long side equals the specified size,
keeping the aspect ratio
'''
def __init__(
self,
size,
skip_low_resolution=False,
interpolation_mode="bilinear",
):
self.size = size
self.skip_low_resolution = skip_low_resolution
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized video clip.
size is (T, C, 512, *) or (T, C, *, 512)
"""
_, _, h, w = clip.shape
if self.skip_low_resolution and max(h, w) <= self.size:
return clip
if h > w:
w = int(w * self.size / h)
h = self.size
else:
h = int(h * self.size / w)
w = self.size
resize_clip = resize(clip, target_size=(h, w),
interpolation_mode=self.interpolation_mode)
return resize_clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class CenterCropResizeVideo:
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop_using_short_edge(clip)
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,
interpolation_mode=self.interpolation_mode)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class UCFCenterCropVideo:
'''
First scale to the specified size in equal proportion to the short edge,
then center cropping
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class KineticsRandomCropResizeVideo:
'''
Slide along the long edge, with the short edge as crop size, then resize to the desired size.
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float and divide value by 255.0.
The dimension order of the clip tensor is left unchanged.
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
class DynamicSampleDuration(object):
"""Temporally crop the given frame indices at a random location.
Args:
t_stride (int): Step between candidate durations.
extra_1 (bool): If True, one frame is set aside before choosing the duration and added back afterwards.
"""
def __init__(self, t_stride, extra_1):
self.t_stride = t_stride
self.extra_1 = extra_1
def __call__(self, t, h, w):
if self.extra_1:
t = t - 1
truncate_t_list = list(range(t+1))[t//2:][::self.t_stride] # need half at least
truncate_t = random.choice(truncate_t_list)
if self.extra_1:
truncate_t = truncate_t + 1
return 0, truncate_t
if __name__ == '__main__':
from torchvision import transforms
import torchvision.io as io
import numpy as np
from torchvision.utils import save_image
import os
vframes, aframes, info = io.read_video(
filename='./v_Archery_g01_c03.avi',
pts_unit='sec',
output_format='TCHW'
)
trans = transforms.Compose([
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,
value_range=(-1, 1))
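# The two temporal samplers above (TemporalRandomCrop and DynamicSampleDuration)
# reduce to a few lines of index arithmetic. The sketch below restates them as plain
# functions so the ranges can be inspected without torch. The function names and
# the rng parameter are hypothetical additions for illustration.

```python
import random


def temporal_random_crop(total_frames, size, rng=random):
    # Same arithmetic as TemporalRandomCrop.__call__.
    rand_end = max(0, total_frames - size - 1)
    begin_index = rng.randint(0, rand_end)  # inclusive on both ends
    end_index = min(begin_index + size, total_frames)
    return begin_index, end_index


def duration_candidates(t, t_stride, extra_1):
    # Candidate durations that DynamicSampleDuration.__call__ chooses from.
    if extra_1:
        t = t - 1
    candidates = list(range(t + 1))[t // 2:][::t_stride]  # at least half of t
    if extra_1:
        candidates = [c + 1 for c in candidates]
    return candidates


rng = random.Random(0)
begin, end = temporal_random_crop(total_frames=100, size=32, rng=rng)
short_begin, short_end = temporal_random_crop(total_frames=10, size=32, rng=rng)
lengths = duration_candidates(t=29, t_stride=4, extra_1=True)
```

# When the video is shorter than the requested size, the window collapses to
# (0, total_frames), so callers that need exactly `size` frames have to check the
# returned length, as the assert in the __main__ block above does.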
================================================
FILE: opensora/models/causalvideovae/dataset/video_dataset.py
================================================
import os.path as osp
import random
from glob import glob
from torchvision import transforms
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import pickle
import decord
from .transform import ToTensorVideo, CenterCropVideo
from torchvision.transforms._transforms_video import CenterCropVideo as TVCenterCropVideo
from torchvision.transforms import Lambda, Compose, Resize
import os
class DecordInit(object):
def __init__(self, num_threads=1):
self.num_threads = num_threads
self.ctx = decord.cpu(0)
def __call__(self, filename):
reader = decord.VideoReader(
filename, ctx=self.ctx, num_threads=self.num_threads
)
return reader
def __repr__(self):
repr_str = (
f"{self.__class__.__name__}("
f"num_threads={self.num_threads})"
)
return repr_str
def TemporalRandomCrop(total_frames, size):
rand_end = max(0, total_frames - size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + size, total_frames)
return begin_index, end_index
def _format_video_shape(video, time_compress=4, spatial_compress=8):
"""Prepare video for VAE"""
time = video.shape[1]
height = video.shape[2]
width = video.shape[3]
new_time = (
(time - (time - 1) % time_compress) if (time - 1) % time_compress != 0 else time
)
new_height = (
(height - (height) % spatial_compress)
if height % spatial_compress != 0
else height
)
new_width = (
(width - (width) % spatial_compress) if width % spatial_compress != 0 else width
)
return video[:, :new_time, :new_height, :new_width]
class TrainVideoDataset(data.Dataset):
video_exts = ["avi", "mp4", "webm"]
def __init__(
self,
video_folder,
sequence_length,
train=True,
resolution=64,
sample_rate=1,
dynamic_sample=True,
cache_file=None,
is_main_process=False,
):
self.train = train
self.sequence_length = sequence_length
self.sample_rate = sample_rate
self.resolution = resolution
self.v_decoder = DecordInit()
self.video_folder = video_folder
self.dynamic_sample = dynamic_sample
self.cache_file = cache_file
self.transform = transforms.Compose(
[
ToTensorVideo(),
Resize(self.resolution),
CenterCropVideo(self.resolution),
Lambda(lambda x: 2.0 * x - 1.0),
]
)
print("Building datasets...")
self.is_main_process = is_main_process
self.samples = self._make_dataset()
def _make_dataset(self):
cache_file = osp.join(self.video_folder, self.cache_file)
if osp.exists(cache_file):
with open(cache_file, "rb") as f:
samples = pickle.load(f)
else:
samples = []
samples += sum(
[
glob(osp.join(self.video_folder, "**", f"*.{ext}"), recursive=True)
for ext in self.video_exts
],
[],
)
if self.is_main_process:
with open(cache_file, "wb") as f:
pickle.dump(samples, f)
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
video_path = self.samples[idx]
try:
video = self.decord_read(video_path)
video = self.transform(video) # T C H W -> T C H W
video = video.transpose(0, 1) # T C H W -> C T H W
return dict(video=video, label="")
except Exception as e:
print(f"Error with {e}, {video_path}")
return self.__getitem__(random.randint(0, self.__len__() - 1))
def decord_read(self, path):
decord_vr = self.v_decoder(path)
total_frames = len(decord_vr)
# Sampling video frames
if self.dynamic_sample:
sample_rate = random.randint(1, self.sample_rate)
else:
sample_rate = self.sample_rate
size = self.sequence_length * sample_rate
start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size)
frame_indice = np.linspace(
start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int
)
video_data = decord_vr.get_batch(frame_indice).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(0, 3, 1, 2)
return video_data
def resize(x, resolution):
height, width = x.shape[-2:]
aspect_ratio = width / height
if width <= height:
new_width = resolution
new_height = int(resolution / aspect_ratio)
else:
new_height = resolution
new_width = int(resolution * aspect_ratio)
resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
return resized_x
class ValidVideoDataset(data.Dataset):
video_exts = ["avi", "mp4", "webm"]
def __init__(
self,
real_video_dir,
num_frames,
sample_rate=1,
crop_size=None,
resolution=128,
is_main_process=False
) -> None:
super().__init__()
self.is_main_process = is_main_process
self.real_video_files = self._make_dataset(real_video_dir)
self.num_frames = num_frames
self.sample_rate = sample_rate
self.crop_size = crop_size
self.short_size = resolution
self.v_decoder = DecordInit()
self.transform = Compose(
[
ToTensorVideo(),
Resize(resolution),
CenterCropVideo(resolution) if crop_size is not None else Lambda(lambda x: x),
]
)
def _make_dataset(self, real_video_dir):
cache_file = osp.join(real_video_dir, "idx.pkl")
if osp.exists(cache_file):
with open(cache_file, "rb") as f:
samples = pickle.load(f)
else:
samples = []
samples += sum(
[
glob(osp.join(real_video_dir, "**", f"*.{ext}"), recursive=True)
for ext in self.video_exts
],
[],
)
if self.is_main_process:
with open(cache_file, "wb") as f:
pickle.dump(samples, f)
return samples
def __len__(self):
return len(self.real_video_files)
def __getitem__(self, index):
try:
if index >= len(self):
raise IndexError
real_video_file = self.real_video_files[index]
real_video_tensor = self._load_video(real_video_file)
real_video_tensor = self.transform(real_video_tensor)
video_name = os.path.basename(real_video_file)
return {'video': real_video_tensor, 'file_name': video_name }
except Exception:
print(f"Video error: {self.real_video_files[index]}")
return self.__getitem__(0)
def _load_video(self, video_path, sample_rate=None):
num_frames = self.num_frames
if not sample_rate:
sample_rate = self.sample_rate
try:
decord_vr = self.v_decoder(video_path)
except Exception as e:
raise Exception(f"fail to load {video_path}.") from e
total_frames = len(decord_vr)
sample_frames_len = sample_rate * num_frames
if total_frames >= sample_frames_len:
s = 0
e = s + sample_frames_len
else:
raise Exception("video too short!")
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(3, 0, 1, 2)
return video_data
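# _format_video_shape above trims a (C, T, H, W) video so that T - 1 is a multiple of
# the temporal compression and H, W are multiples of the spatial compression. The
# sketch below reproduces only the shape arithmetic on a tuple. The helper names
# are hypothetical and no tensor is sliced.

```python
def trimmed_length(time, time_compress=4):
    # Largest length <= time such that (length - 1) is a multiple of time_compress.
    return time - (time - 1) % time_compress


def trimmed_side(side, spatial_compress=8):
    # Largest value <= side that is a multiple of spatial_compress.
    return side - side % spatial_compress


def formatted_shape(shape, time_compress=4, spatial_compress=8):
    channels, time, height, width = shape
    return (
        channels,
        trimmed_length(time, time_compress),
        trimmed_side(height, spatial_compress),
        trimmed_side(width, spatial_compress),
    )


example = formatted_shape((3, 35, 130, 257))
already_valid = formatted_shape((3, 33, 128, 256))
```

# The conditional expressions in the original are equivalent to these one-liners,
# because subtracting a remainder of zero leaves the value unchanged.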
================================================
FILE: opensora/models/causalvideovae/eval/cal_fvd.py
================================================
import numpy as np
import torch
from tqdm import tqdm
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
x = x.repeat(1, 1, 3, 1, 1)
# permute BTCHW -> BCTHW
x = x.permute(0, 2, 1, 3, 4)
return x
def calculate_fvd(videos1, videos2, device, method='styleganv'):
if method == 'styleganv':
from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
elif method == 'videogpt':
from fvd.videogpt.fvd import load_i3d_pretrained
from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
from fvd.videogpt.fvd import frechet_distance
else:
raise ValueError(f"Unknown FVD method: {method}")
print("calculate_fvd...")
# videos [batch_size, timestamps, channel, h, w]
assert videos1.shape == videos2.shape
i3d = load_i3d_pretrained(device=device)
# support grayscale input, if grayscale -> channel*3
# BTCHW -> BCTHW
# videos -> [batch_size, channel, timestamps, h, w]
videos1 = trans(videos1)
videos2 = trans(videos2)
fvd_results = {}
# for calculate FVD, each clip_timestamp must >= 10
for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
# get a video clip
# videos_clip [batch_size, channel, timestamps[:clip], h, w]
videos_clip1 = videos1[:, :, : clip_timestamp]
videos_clip2 = videos2[:, :, : clip_timestamp]
# get FVD features
feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
# calculate FVD when timestamps[:clip]
fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)
    result = {
        "value": fvd_results,
        "video_setting": videos1.shape,
        "video_setting_name": "batch_size, channel, time, height, width",
    }
return result
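For intuition about what `frechet_distance` returns, the one-dimensional case can be worked by hand. The sketch below is an illustration only, not the repository's implementation: it assumes scalar features, so the covariance matrices reduce to variances and the matrix square root to an ordinary square root. The name `frechet_distance_1d` is hypothetical.

```python
import math
import statistics


def frechet_distance_1d(xs, ys):
    """Frechet distance between two 1-D Gaussians fitted to xs and ys."""
    mu_x, mu_y = statistics.mean(xs), statistics.mean(ys)
    # Sample (unbiased) variances, matching the unbiased covariance estimate.
    var_x, var_y = statistics.variance(xs), statistics.variance(ys)
    return (mu_x - mu_y) ** 2 + var_x + var_y - 2.0 * math.sqrt(var_x * var_y)


if __name__ == "__main__":
    real = [0.0, 1.0, 2.0, 3.0]
    shifted = [1.0, 2.0, 3.0, 4.0]
    print(frechet_distance_1d(real, real))
    print(frechet_distance_1d(real, shifted))
```

Shifting every sample by one leaves the variance unchanged, so only the squared mean difference contributes to the distance.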
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
device = torch.device("cuda")
# device = torch.device("cpu")
import json
result = calculate_fvd(videos1, videos2, device, method='videogpt')
print(json.dumps(result, indent=4))
result = calculate_fvd(videos1, videos2, device, method='styleganv')
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()
================================================
FILE: opensora/models/causalvideovae/eval/cal_lpips.py
================================================
import numpy as np
import torch
from tqdm import tqdm
import lpips
spatial = True # Return a spatial map of perceptual distance.
# Linearly calibrated models (LPIPS)
loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
x = x.repeat(1, 1, 3, 1, 1)
# value range [0, 1] -> [-1, 1]
x = x * 2 - 1
return x
def calculate_lpips(videos1, videos2, device):
# image should be RGB, IMPORTANT: normalized to [-1,1]
print("calculate_lpips...")
assert videos1.shape == videos2.shape
# videos [batch_size, timestamps, channel, h, w]
# support grayscale input, if grayscale -> channel*3
# value range [0, 1] -> [-1, 1]
videos1 = trans(videos1)
videos2 = trans(videos2)
lpips_results = []
for video_num in tqdm(range(videos1.shape[0])):
# get a video
# video [timestamps, channel, h, w]
video1 = videos1[video_num]
video2 = videos2[video_num]
lpips_results_of_a_video = []
for clip_timestamp in range(len(video1)):
# get a img
# img [timestamps[x], channel, h, w]
# img [channel, h, w] tensor
img1 = video1[clip_timestamp].unsqueeze(0).to(device)
img2 = video2[clip_timestamp].unsqueeze(0).to(device)
loss_fn.to(device)
# calculate lpips of a video
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
lpips_results.append(lpips_results_of_a_video)
lpips_results = np.array(lpips_results)
lpips = {}
lpips_std = {}
for clip_timestamp in range(len(video1)):
lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp])
lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp])
    result = {
        "value": lpips,
        "value_std": lpips_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
    }
return result
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
device = torch.device("cuda")
# device = torch.device("cpu")
import json
result = calculate_lpips(videos1, videos2, device)
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()
================================================
FILE: opensora/models/causalvideovae/eval/cal_psnr.py
================================================
import numpy as np
import torch
from tqdm import tqdm
import math
def img_psnr_cuda(img1, img2):
# [0,1]
# compute mse
# mse = np.mean((img1-img2)**2)
mse = torch.mean((img1 / 1.0 - img2 / 1.0) ** 2)
# compute psnr
if mse < 1e-10:
return 100
psnr = 20 * torch.log10(1 / torch.sqrt(mse))
return psnr
def img_psnr(img1, img2):
# [0,1]
# compute mse
# mse = np.mean((img1-img2)**2)
mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
# compute psnr
if mse < 1e-10:
return 100
psnr = 20 * math.log10(1 / math.sqrt(mse))
return psnr
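A small worked example of the PSNR formula used above, for pixel values in [0, 1]. This is a sketch on plain Python lists; `psnr_unit_range` is a hypothetical name, and the cap of 100 for near-identical inputs mirrors the `mse < 1e-10` branch in `img_psnr`.

```python
import math


def psnr_unit_range(pixels_a, pixels_b):
    """PSNR in dB for two equally long sequences of values in [0, 1]."""
    mse = sum((a - b) ** 2 for a, b in zip(pixels_a, pixels_b)) / len(pixels_a)
    if mse < 1e-10:
        # Identical (or nearly identical) inputs: return the same cap as img_psnr.
        return 100.0
    return 20 * math.log10(1 / math.sqrt(mse))


if __name__ == "__main__":
    clean = [0.0, 0.0, 0.0, 0.0]
    noisy = [0.1, 0.1, 0.1, 0.1]
    print(psnr_unit_range(clean, noisy))
    print(psnr_unit_range(clean, clean))
```

A uniform error of 0.1 gives an MSE of 0.01, which corresponds to roughly 20 dB.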
def trans(x):
return x
def calculate_psnr(videos1, videos2):
print("calculate_psnr...")
# videos [batch_size, timestamps, channel, h, w]
assert videos1.shape == videos2.shape
videos1 = trans(videos1)
videos2 = trans(videos2)
psnr_results = []
for video_num in tqdm(range(videos1.shape[0])):
# get a video
# video [timestamps, channel, h, w]
video1 = videos1[video_num]
video2 = videos2[video_num]
psnr_results_of_a_video = []
for clip_timestamp in range(len(video1)):
# get a img
# img [timestamps[x], channel, h, w]
# img [channel, h, w] numpy
            # move to CPU first: .numpy() raises on CUDA tensors
            img1 = video1[clip_timestamp].cpu().numpy()
            img2 = video2[clip_timestamp].cpu().numpy()
# calculate psnr of a video
psnr_results_of_a_video.append(img_psnr(img1, img2))
psnr_results.append(psnr_results_of_a_video)
psnr_results = np.array(psnr_results) # [batch_size, num_frames]
psnr = {}
psnr_std = {}
for clip_timestamp in range(len(video1)):
psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp])
psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp])
    result = {
        "value": psnr,
        "value_std": psnr_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
    }
return result
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
import json
result = calculate_psnr(videos1, videos2)
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()
================================================
FILE: opensora/models/causalvideovae/eval/cal_ssim.py
================================================
import numpy as np
import torch
from tqdm import tqdm
import cv2
def ssim(img1, img2):
C1 = 0.01 ** 2
C2 = 0.03 ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
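The SSIM expression above can be checked on tiny inputs with a simplified variant. The sketch below is not the windowed computation in `ssim`: it assumes a single global window (plain means, variances and covariance over all pixels) instead of the 11x11 Gaussian filter, and `global_ssim` is a hypothetical name. The constants match `C1` and `C2` above.

```python
def global_ssim(xs, ys, c1=0.01 ** 2, c2=0.03 ** 2):
    """SSIM over one global window for two equally long sequences in [0, 1]."""
    n = len(xs)
    mu_x = sum(xs) / n
    mu_y = sum(ys) / n
    var_x = sum((x - mu_x) ** 2 for x in xs) / n
    var_y = sum((y - mu_y) ** 2 for y in ys) / n
    cov = sum((x - mu_x) * (y - mu_y) for x, y in zip(xs, ys)) / n
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


if __name__ == "__main__":
    ramp = [0.0, 0.25, 0.5, 0.75, 1.0]
    print(global_ssim(ramp, ramp))              # identical inputs
    print(global_ssim([0.0] * 4, [1.0] * 4))    # black vs. white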
def calculate_ssim_function(img1, img2):
# [0,1]
# ssim is the only metric extremely sensitive to gray being compared to b/w
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[0] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[i], img2[i]))
return np.array(ssims).mean()
elif img1.shape[0] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')
def trans(x):
return x
def calculate_ssim(videos1, videos2):
print("calculate_ssim...")
# videos [batch_size, timestamps, channel, h, w]
assert videos1.shape == videos2.shape
videos1 = trans(videos1)
videos2 = trans(videos2)
ssim_results = []
for video_num in tqdm(range(videos1.shape[0])):
# get a video
# video [timestamps, channel, h, w]
video1 = videos1[video_num]
video2 = videos2[video_num]
ssim_results_of_a_video = []
for clip_timestamp in range(len(video1)):
# get a img
# img [timestamps[x], channel, h, w]
# img [channel, h, w] numpy
            # move to CPU first: .numpy() raises on CUDA tensors
            img1 = video1[clip_timestamp].cpu().numpy()
            img2 = video2[clip_timestamp].cpu().numpy()
# calculate ssim of a video
ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
ssim_results.append(ssim_results_of_a_video)
ssim_results = np.array(ssim_results)
ssim = {}
ssim_std = {}
for clip_timestamp in range(len(video1)):
ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp])
ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp])
    result = {
        "value": ssim,
        "value_std": ssim_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
    }
return result
# test code / using example
def main():
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 50
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
import json
result = calculate_ssim(videos1, videos2)
print(json.dumps(result, indent=4))
if __name__ == "__main__":
main()
================================================
FILE: opensora/models/causalvideovae/eval/eval.py
================================================
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import sys
from glob import glob
sys.path.append(".")
from opensora.models.causalvideovae.eval.cal_lpips import calculate_lpips
from opensora.models.causalvideovae.eval.cal_fvd import calculate_fvd
from opensora.models.causalvideovae.eval.cal_psnr import calculate_psnr
from opensora.models.causalvideovae.eval.cal_ssim import calculate_ssim
from opensora.models.causalvideovae.dataset.video_dataset import (
ValidVideoDataset,
DecordInit,
Compose,
Lambda,
resize,
CenterCropVideo,
ToTensorVideo
)
class EvalDataset(ValidVideoDataset):
def __init__(
self,
real_video_dir,
generated_video_dir,
num_frames,
sample_rate=1,
crop_size=None,
resolution=128,
) -> None:
self.is_main_process = False
self.v_decoder = DecordInit()
self.real_video_files = []
self.generated_video_files = self._make_dataset(generated_video_dir)
for video_file in self.generated_video_files:
filename = os.path.basename(video_file)
            if not os.path.exists(os.path.join(real_video_dir, filename)):
                raise FileNotFoundError(
                    f"No real video matching generated video {filename!r}: "
                    f"{os.path.join(real_video_dir, filename)}"
                )
self.real_video_files.append(os.path.join(real_video_dir, filename))
self.num_frames = num_frames
self.sample_rate = sample_rate
self.crop_size = crop_size
self.short_size = resolution
self.transform = Compose(
[
ToTensorVideo(),
Lambda(lambda x: resize(x, self.short_size)),
(
CenterCropVideo(crop_size)
if crop_size is not None
else Lambda(lambda x: x)
),
]
)
def _make_dataset(self, real_video_dir):
samples = []
samples += sum(
[
glob(os.path.join(real_video_dir, f"*.{ext}"), recursive=True)
for ext in self.video_exts
],
[],
)
return samples
def __len__(self):
return len(self.real_video_files)
def __getitem__(self, index):
if index >= len(self):
raise IndexError
real_video_file = self.real_video_files[index]
generated_video_file = self.generated_video_files[index]
real_video_tensor = self._load_video(real_video_file, self.sample_rate)
generated_video_tensor = self._load_video(generated_video_file, 1)
return {"real": self.transform(real_video_tensor), "generated": self.transform(generated_video_tensor)}
def calculate_common_metric(args, dataloader, device):
score_list = []
for batch_data in tqdm(dataloader):
real_videos = batch_data["real"].to(device)
generated_videos = batch_data["generated"].to(device)
assert real_videos.shape[2] == generated_videos.shape[2]
        if args.metric == "fvd":
            tmp_list = list(
                calculate_fvd(
                    # use the resolved device; args.device may be None
                    real_videos, generated_videos, device, method=args.fvd_method
                )["value"].values()
            )
elif args.metric == "ssim":
tmp_list = list(
calculate_ssim(real_videos, generated_videos)["value"].values()
)
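        elif args.metric == "psnr":
            # calculate_psnr returns a dict; average its per-frame "value" entries
            tmp_list = list(
                calculate_psnr(real_videos, generated_videos)["value"].values()
            )
        else:
            tmp_list = list(
                calculate_lpips(real_videos, generated_videos, device)["value"].values()
            )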
score_list += tmp_list
return np.mean(score_list)
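The aggregation in `calculate_common_metric` relies on each metric function returning a dict whose `"value"` entry maps a frame index to a score. A minimal sketch of that reduction follows, using hand-written result dicts in place of real metric outputs; `mean_metric` is a hypothetical helper, not a function in this repository.

```python
from statistics import mean


def mean_metric(batch_results):
    """Average every per-frame score across all batch result dicts."""
    scores = []
    for result in batch_results:
        # Each result looks like {"value": {frame_index: score, ...}, ...}.
        scores += list(result["value"].values())
    return mean(scores)


if __name__ == "__main__":
    batches = [
        {"value": {0: 30.0, 1: 32.0}},
        {"value": {0: 34.0, 1: 36.0}},
    ]
    print(mean_metric(batches))
```

Appending the result dict itself instead of its `"value"` scores would leave nothing numeric to average, which is why the values are flattened first.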
def main():
if args.device is None:
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
else:
device = torch.device(args.device)
if args.num_workers is None:
try:
num_cpus = len(os.sched_getaffinity(0))
except AttributeError:
num_cpus = os.cpu_count()
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
else:
num_workers = args.num_workers
dataset = EvalDataset(
args.real_video_dir,
args.generated_video_dir,
num_frames=args.num_frames,
sample_rate=args.sample_rate,
crop_size=args.crop_size,
resolution=args.resolution,
)
if args.subset_size:
indices = range(args.subset_size)
dataset = Subset(dataset, indices=indices)
dataloader = DataLoader(
dataset, args.batch_size, num_workers=num_workers, pin_memory=True
)
metric_score = calculate_common_metric(args, dataloader, device)
print(metric_score)
def parse_args():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use")
    parser.add_argument("--real_video_dir", type=str, help="the path of real videos")
    parser.add_argument(
        "--generated_video_dir", type=str, help="the path of generated videos"
    )
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to use. Like cuda, cuda:0 or cpu",
)
    parser.add_argument(
        "--num_workers",
        type=int,
        default=None,  # None lets main() pick min(8, num_cpus), as the help text states
        help=(
            "Number of processes to use for data loading. "
            "Defaults to `min(8, num_cpus)`"
        ),
    )
parser.add_argument("--sample_fps", type=int, default=30)
parser.add_argument("--resolution", type=int, default=336)
parser.add_argument("--crop_size", type=int, default=None)
parser.add_argument("--num_frames", type=int, default=100)
parser.add_argument("--sample_rate", type=int, default=1)
parser.add_argument("--subset_size", type=int, default=None)
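    parser.add_argument(
        "--metric",
        type=str,
        default="fvd",
        choices=["fvd", "psnr", "ssim", "lpips", "flolpips"],
    )
    # calculate_common_metric reads args.fvd_method; the choices are the two
    # methods handled by calculate_fvd.
    parser.add_argument(
        "--fvd_method",
        type=str,
        default="styleganv",
        choices=["styleganv", "videogpt"],
    )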
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main()
================================================
FILE: opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py
================================================
import torch
import os
import math
import numpy as np  # used by get_feats below
import torch.nn.functional as F
# https://github.com/universome/fvd-comparison
def load_i3d_pretrained(device=torch.device('cpu')):
i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
print(filepath)
if not os.path.exists(filepath):
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
i3d = torch.jit.load(filepath).eval().to(device)
i3d = torch.nn.DataParallel(i3d)
return i3d
def get_feats(videos, detector, device, bs=10):
# videos : torch.tensor BCTHW [0, 1]
detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
feats = np.empty((0, 400))
with torch.no_grad():
for i in range((len(videos)-1)//bs + 1):
feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])
return feats
def get_fvd_feats(videos, i3d, device, bs=10):
# videos in [0, 1] as torch tensor BCTHW
# videos = [preprocess_single(video) for video in videos]
embeddings = get_feats(videos, i3d, device, bs)
return embeddings
def preprocess_single(video, resolution=224, sequence_length=None):
# video: CTHW, [0, 1]
c, t, h, w = video.shape
# temporal crop
if sequence_length is not None:
assert sequence_length <= t
video = video[:, :sequence_length]
# scale shorter side to resolution
scale = resolution / min(h, w)
if h < w:
target_size = (resolution, math.ceil(w * scale))
else:
target_size = (math.ceil(h * scale), resolution)
video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)
# center crop
c, t, h, w = video.shape
w_start = (w - resolution) // 2
h_start = (h - resolution) // 2
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
# [0, 1] -> [-1, 1]
video = (video - 0.5) * 2
return video.contiguous()
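The resize-then-center-crop arithmetic in `preprocess_single` can be checked separately from the tensor operations. The sketch below reproduces only the size and offset computation; `resize_and_crop_geometry` is a hypothetical helper, not part of the repository.

```python
import math


def resize_and_crop_geometry(h, w, resolution=224):
    """Return ((new_h, new_w), (h_start, w_start)) for a shorter-side resize
    to `resolution` followed by a resolution x resolution center crop."""
    scale = resolution / min(h, w)
    if h < w:
        target = (resolution, math.ceil(w * scale))
    else:
        target = (math.ceil(h * scale), resolution)
    h_start = (target[0] - resolution) // 2
    w_start = (target[1] - resolution) // 2
    return target, (h_start, w_start)


if __name__ == "__main__":
    print(resize_and_crop_geometry(240, 320))
    print(resize_and_crop_geometry(64, 64))
```

For a landscape frame the height lands exactly on `resolution` and only the width is cropped; a square frame needs no crop at all.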
"""
Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
"""
from typing import Tuple
from scipy.linalg import sqrtm
import numpy as np
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
mu = feats.mean(axis=0) # [d]
sigma = np.cov(feats, rowvar=False) # [d, d]
return mu, sigma
def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
mu_gen, sigma_gen = compute_stats(feats_fake)
mu_real, sigma_real = compute_stats(feats_real)
m = np.square(mu_gen - mu_real).sum()
if feats_fake.shape[0]>1:
s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
else:
fid = np.real(m)
return float(fid)
================================================
FILE: opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py
================================================
import torch
import os
import math
import torch.nn.functional as F
import numpy as np
import einops
def load_i3d_pretrained(device=torch.device('cpu')):
i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
print(filepath)
if not os.path.exists(filepath):
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
from .pytorch_i3d import InceptionI3d
i3d = InceptionI3d(400, in_channels=3).eval().to(device)
i3d.load_state_dict(torch.load(filepath, map_location=device))
i3d = torch.nn.DataParallel(i3d)
return i3d
def preprocess_single(video, resolution, sequence_length=None):
# video: THWC, {0, ..., 255}
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
t, c, h, w = video.shape
# temporal crop
if sequence_length is not None:
assert sequence_length <= t
video = video[:sequence_length]
# scale shorter side to resolution
scale = resolution / min(h, w)
if h < w:
target_size = (resolution, math.ceil(w * scale))
else:
target_size = (math.ceil(h * scale), resolution)
video = F.interpolate(video, size=target_size, mode='bilinear',
align_corners=False)
# center crop
t, c, h, w = video.shape
w_start = (w - resolution) // 2
h_start = (h - resolution) // 2
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
video = video.permute(1, 0, 2, 3).contiguous() # CTHW
video -= 0.5
return video
def preprocess(videos, target_resolution=224):
    # convert videos in [0, 1], shape [b c t h w], torch.float
    # -> videos in {0, ..., 255}, shape [b t h w c], np.uint8 array
    videos = einops.rearrange(videos, 'b c t h w -> b t h w c')
    videos = (videos * 255).cpu().numpy().astype(np.uint8)  # .cpu(): .numpy() raises on CUDA tensors
b, t, h, w, c = videos.shape
videos = torch.from_numpy(videos)
videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])
return videos * 2 # [-0.5, 0.5] -> [-1, 1]
def get_fvd_logits(videos, i3d, device, bs=10):
videos = preprocess(videos)
    embeddings = get_logits(i3d, videos, device, bs=bs)
return embeddings
# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
def _symmetric_matrix_square_root(mat, eps=1e-10):
u, s, v = torch.svd(mat)
si = torch.where(s < eps, s, torch.sqrt(s))
return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
def trace_sqrt_product(sigma, sigma_v):
sqrt_sigma = _symmetric_matrix_square_root(sigma)
sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def cov(m, rowvar=False):
'''Estimate a covariance matrix given data.
Covariance indicates the level to which two variables vary together.
If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
then the covariance matrix element `C_{ij}` is the covariance of
`x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
Args:
m: A 1-D or 2-D array containing multiple variables and observations.
Each row of `m` represents a variable, and each column a single
observation of all those variables.
rowvar: If `rowvar` is True, then each row represents a
variable, with observations in the columns. Otherwise, the
relationship is transposed: each column represents a variable,
while the rows contain observations.
Returns:
The covariance matrix of the variables.
'''
if m.dim() > 2:
raise ValueError('m has more than 2 dimensions')
if m.dim() < 2:
m = m.view(1, -1)
if not rowvar and m.size(0) != 1:
m = m.t()
fact = 1.0 / (m.size(1) - 1) # unbiased estimate
    m = m - torch.mean(m, dim=1, keepdim=True)  # out-of-place, so the caller's tensor is not modified
mt = m.t() # if complex: mt = m.t().conj()
return fact * m.matmul(mt).squeeze()
def frechet_distance(x1, x2):
x1 = x1.flatten(start_dim=1)
x2 = x2.flatten(start_dim=1)
m, m_w = x1.mean(dim=0), x2.mean(dim=0)
sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
mean = torch.sum((m - m_w) ** 2)
if x1.shape[0]>1:
sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
fd = trace + mean
else:
fd = np.real(mean)
return float(fd)
def get_logits(i3d, videos, device, bs=10):
# assert videos.shape[0] % 16 == 0
with torch.no_grad():
logits = []
for i in range(0, videos.shape[0], bs):
batch = videos[i:i + bs].to(device)
# logits.append(i3d.module.extract_features(batch)) # wrong
logits.append(i3d(batch)) # right
logits = torch.cat(logits, dim=0)
return logits
================================================
FILE: opensora/models/causalvideovae/eval/fvd/videogpt/pytorch_i3d.py
================================================
# Original code from https://github.com/piergiaj/pytorch-i3d
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class MaxPool3dSamePadding(nn.MaxPool3d):
def compute_pad(self, dim, s):
if s % self.stride[dim] == 0:
return max(self.kernel_size[dim] - self.stride[dim], 0)
else:
return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
def forward(self, x):
# compute 'same' padding
(batch, channel, t, h, w) = x.size()
out_t = np.ceil(float(t) / float(self.stride[0]))
out_h = np.ceil(float(h) / float(self.stride[1]))
out_w = np.ceil(float(w) / float(self.stride[2]))
pad_t = self.compute_pad(0, t)
pad_h = self.compute_pad(1, h)
pad_w = self.compute_pad(2, w)
pad_t_f = pad_t // 2
pad_t_b = pad_t - pad_t_f
pad_h_f = pad_h // 2
pad_h_b = pad_h - pad_h_f
pad_w_f = pad_w // 2
pad_w_b = pad_w - pad_w_f
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
x = F.pad(x, pad)
return super(MaxPool3dSamePadding, self).forward(x)
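The `compute_pad` rule above is easier to follow on a single axis. The sketch below restates it with plain integers and checks the resulting output length; `same_padding` and `output_length` are hypothetical helper names used only for illustration.

```python
def same_padding(size, kernel, stride):
    """Return (front, back) padding for one axis, following compute_pad."""
    if size % stride == 0:
        total = max(kernel - stride, 0)
    else:
        total = max(kernel - (size % stride), 0)
    front = total // 2
    # Any odd remainder goes to the back, as in the forward() methods.
    return front, total - front


def output_length(size, kernel, stride):
    """Length after padding and a stride-`stride` window of width `kernel`."""
    front, back = same_padding(size, kernel, stride)
    return (size + front + back - kernel) // stride + 1


if __name__ == "__main__":
    for size in (224, 225):
        print(size, same_padding(size, 7, 2), output_length(size, 7, 2))
```

In both cases printed here the output length equals the input length divided by the stride, rounded up, which is the usual meaning of 'same' padding.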
class Unit3D(nn.Module):
def __init__(self, in_channels,
output_channels,
kernel_shape=(1, 1, 1),
stride=(1, 1, 1),
padding=0,
activation_fn=F.relu,
use_batch_norm=True,
use_bias=False,
name='unit_3d'):
"""Initializes Unit3D module."""
super(Unit3D, self).__init__()
self._output_channels = output_channels
self._kernel_shape = kernel_shape
self._stride = stride
self._use_batch_norm = use_batch_norm
self._activation_fn = activation_fn
self._use_bias = use_bias
self.name = name
self.padding = padding
self.conv3d = nn.Conv3d(in_channels=in_channels,
out_channels=self._output_channels,
kernel_size=self._kernel_shape,
stride=self._stride,
padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
bias=self._use_bias)
if self._use_batch_norm:
self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)
def compute_pad(self, dim, s):
if s % self._stride[dim] == 0:
return max(self._kernel_shape[dim] - self._stride[dim], 0)
else:
return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
def forward(self, x):
# compute 'same' padding
(batch, channel, t, h, w) = x.size()
out_t = np.ceil(float(t) / float(self._stride[0]))
out_h = np.ceil(float(h) / float(self._stride[1]))
out_w = np.ceil(float(w) / float(self._stride[2]))
pad_t = self.compute_pad(0, t)
pad_h = self.compute_pad(1, h)
pad_w = self.compute_pad(2, w)
pad_t_f = pad_t // 2
pad_t_b = pad_t - pad_t_f
pad_h_f = pad_h // 2
pad_h_b = pad_h - pad_h_f
pad_w_f = pad_w // 2
pad_w_b = pad_w - pad_w_f
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
x = F.pad(x, pad)
x = self.conv3d(x)
if self._use_batch_norm:
x = self.bn(x)
if self._activation_fn is not None:
x = self._activation_fn(x)
return x
class InceptionModule(nn.Module):
def __init__(self, in_channels, out_channels, name):
super(InceptionModule, self).__init__()
self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_0/Conv3d_0a_1x1')
self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_1/Conv3d_0a_1x1')
self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
name=name+'/Branch_1/Conv3d_0b_3x3')
self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_2/Conv3d_0a_1x1')
self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
name=name+'/Branch_2/Conv3d_0b_3x3')
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
stride=(1, 1, 1), padding=0)
self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_3/Conv3d_0b_1x1')
self.name = name
def forward(self, x):
b0 = self.b0(x)
b1 = self.b1b(self.b1a(x))
b2 = self.b2b(self.b2a(x))
b3 = self.b3b(self.b3a(x))
return torch.cat([b0,b1,b2,b3], dim=1)
class InceptionI3d(nn.Module):
"""Inception-v1 I3D architecture.
The model is introduced in:
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
Joao Carreira, Andrew Zisserman
https://arxiv.org/pdf/1705.07750v1.pdf.
See also the Inception architecture, introduced in:
Going deeper with convolutions
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
http://arxiv.org/pdf/1409.4842v1.pdf.
"""
# Endpoints of the model in order. During construction, all the endpoints up
# to a designated `final_endpoint` are returned in a dictionary as the
# second return value.
VALID_ENDPOINTS = (
'Conv3d_1a_7x7',
'MaxPool3d_2a_3x3',
'Conv3d_2b_1x1',
'Conv3d_2c_3x3',
'MaxPool3d_3a_3x3',
'Mixed_3b',
'Mixed_3c',
'MaxPool3d_4a_3x3',
'Mixed_4b',
'Mixed_4c',
'Mixed_4d',
'Mixed_4e',
'Mixed_4f',
'MaxPool3d_5a_2x2',
'Mixed_5b',
'Mixed_5c',
'Logits',
'Predictions',
)
def __init__(self, num_classes=400, spatial_squeeze=True,
final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
"""Initializes I3D model instance.
Args:
num_classes: The number of outputs in the logit layer (default 400, which
matches the Kinetics dataset).
spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
before returning (default True).
final_endpoint: The model contains many possible endpoints.
`final_endpoint` specifies the last endpoint for the model to be built
up to. In addition to the output at `final_endpoint`, all the outputs
at endpoints up to `final_endpoint` will also be returned, in a
dictionary. `final_endpoint` must be one of
InceptionI3d.VALID_ENDPOINTS (default 'Logits').
name: A string (optional). The name of this module.
Raises:
ValueError: if `final_endpoint` is not recognized.
"""
if final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError('Unknown final endpoint %s' % final_endpoint)
super(InceptionI3d, self).__init__()
self._num_classes = num_classes
self._spatial_squeeze = spatial_squeeze
self._final_endpoint = final_endpoint
self.logits = None
if self._final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
self.end_points = {}
end_point = 'Conv3d_1a_7x7'
self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
stride=(2, 2, 2), padding=(3,3,3), name=name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_2a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Conv3d_2b_1x1'
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
name=name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Conv3d_2c_3x3'
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
name=name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_3a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Mixed_3b'
self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_3c'
self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_4a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4b'
self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4c'
self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4d'
self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4e'
self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4f'
self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_5a_2x2'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Mixed_5b'
self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_5c'
self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Logits'
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
stride=(1, 1, 1))
self.dropout = nn.Dropout(dropout_keep_prob)
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
kernel_shape=[1, 1, 1],
padding=0,
activation_fn=None,
use_batch_norm=False,
use_bias=True,
name='logits')
self.build()
def replace_logits(self, num_classes):
self._num_classes = num_classes
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
kernel_shape=[1, 1, 1],
padding=0,
activation_fn=None,
use_batch_norm=False,
use_bias=True,
name='logits')
def build(self):
for k in self.end_points.keys():
self.add_module(k, self.end_points[k])
def forward(self, x):
for end_point in self.VALID_ENDPOINTS:
if end_point in self.end_points:
x = self._modules[end_point](x) # use _modules to work with dataparallel
x = self.logits(self.dropout(self.avg_pool(x)))
if self._spatial_squeeze:
logits = x.squeeze(3).squeeze(3)
logits = logits.mean(dim=2)
# after averaging over the time axis (dim=2), logits is batch X classes
return logits
def extract_features(self, x):
for end_point in self.VALID_ENDPOINTS:
if end_point in self.end_points:
x = self._modules[end_point](x)
return self.avg_pool(x)
================================================
FILE: opensora/models/causalvideovae/eval/script/cal_clip_score.sh
================================================
# clip_score cross modality
python eval_clip_score.py \
--real_path path/to/image \
--generated_path path/to/text \
--batch-size 50 \
--device "cuda"
# clip_score within the same modality
python eval_clip_score.py \
--real_path path/to/textA \
--generated_path path/to/textB \
--real_flag txt \
--generated_flag txt \
--batch-size 50 \
--device "cuda"
python eval_clip_score.py \
--real_path path/to/imageA \
--generated_path path/to/imageB \
--real_flag img \
--generated_flag img \
--batch-size 50 \
--device "cuda"
================================================
FILE: opensora/models/causalvideovae/eval/script/cal_fvd.sh
================================================
python eval_common_metric.py \
--real_video_dir path/to/imageA \
--generated_video_dir path/to/imageB \
--batch_size 10 \
--crop_size 64 \
--num_frames 20 \
--device 'cuda' \
--metric 'fvd' \
--fvd_method 'styleganv'
================================================
FILE: opensora/models/causalvideovae/eval/script/cal_lpips.sh
================================================
python eval_common_metric.py \
--real_video_dir path/to/imageA \
--generated_video_dir path/to/imageB \
--batch_size 10 \
--num_frames 20 \
--crop_size 64 \
--device 'cuda' \
--metric 'lpips'
================================================
FILE: opensora/models/causalvideovae/eval/script/cal_psnr.sh
================================================
python eval_common_metric.py \
--real_video_dir /data/xiaogeng_liu/data/video1 \
--generated_video_dir /data/xiaogeng_liu/data/video2 \
--batch_size 10 \
--num_frames 20 \
--crop_size 64 \
--device 'cuda' \
--metric 'psnr'
================================================
FILE: opensora/models/causalvideovae/eval/script/cal_ssim.sh
================================================
python eval_common_metric.py \
--real_video_dir /data/xiaogeng_liu/data/video1 \
--generated_video_dir /data/xiaogeng_liu/data/video2 \
--batch_size 10 \
--num_frames 20 \
--crop_size 64 \
--device 'cuda' \
--metric 'ssim'
================================================
FILE: opensora/models/causalvideovae/model/__init__.py
================================================
from .registry import ModelRegistry
from .vae import (
CausalVAEModel, WFVAEModel
)
================================================
FILE: opensora/models/causalvideovae/model/configuration_videobase.py
================================================
import json
import yaml
from typing import TypeVar, Dict, Any
from diffusers import ConfigMixin
T = TypeVar('T', bound='VideoBaseConfiguration')
class VideoBaseConfiguration(ConfigMixin):
config_name = "VideoBaseConfiguration"
_nested_config_fields: Dict[str, Any] = {}
def __init__(self, **kwargs):
pass
def to_dict(self) -> Dict[str, Any]:
d = {}
for key, value in vars(self).items():
if isinstance(value, VideoBaseConfiguration):
d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances
elif isinstance(value, tuple):
d[key] = list(value)
else:
d[key] = value
return d
def to_yaml_file(self, yaml_path: str):
with open(yaml_path, 'w') as yaml_file:
yaml.dump(self.to_dict(), yaml_file, default_flow_style=False)
@classmethod
def load_from_yaml(cls: T, yaml_path: str) -> T:
with open(yaml_path, 'r') as yaml_file:
config_dict = yaml.safe_load(yaml_file)
for field, field_type in cls._nested_config_fields.items():
if field in config_dict:
config_dict[field] = field_type.load_from_dict(config_dict[field])
return cls(**config_dict)
@classmethod
def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T:
# Process nested configuration objects
for field, field_type in cls._nested_config_fields.items():
if field in config_dict:
config_dict[field] = field_type.load_from_dict(config_dict[field])
return cls(**config_dict)
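# Illustrative sketch (not part of the original module): how nested configurations
# round-trip through to_dict() / load_from_dict(). The classes `Config`,
# `EncoderConfig` and `ModelConfig` are hypothetical stand-ins, and storing kwargs
# as attributes in __init__ is an assumption (the base class above only does `pass`).

```python
class Config:
    # Maps a field name to the config class used to rebuild it from a dict.
    _nested_config_fields = {}

    def __init__(self, **kwargs):
        # Assumption: fields are stored as plain attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self):
        d = {}
        for key, value in vars(self).items():
            if isinstance(value, Config):
                d[key] = value.to_dict()      # serialize nested configs recursively
            elif isinstance(value, tuple):
                d[key] = list(value)          # tuples become lists (YAML/JSON friendly)
            else:
                d[key] = value
        return d

    @classmethod
    def load_from_dict(cls, config_dict):
        config_dict = dict(config_dict)       # copy so the caller's dict is untouched
        for field, field_type in cls._nested_config_fields.items():
            if field in config_dict:
                config_dict[field] = field_type.load_from_dict(config_dict[field])
        return cls(**config_dict)


class EncoderConfig(Config):
    pass


class ModelConfig(Config):
    _nested_config_fields = {"encoder": EncoderConfig}


cfg = ModelConfig(latent_dim=4, encoder=EncoderConfig(channels=(64, 128)))
as_dict = cfg.to_dict()
restored = ModelConfig.load_from_dict(as_dict)
```

# Note that tuples do not survive the round trip: `channels` comes back as a list.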
================================================
FILE: opensora/models/causalvideovae/model/dataset_videobase.py
================================================
import os.path as osp
import random
from glob import glob
from torchvision import transforms
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
from torchvision.transforms import Lambda
from ..dataset.transform import ToTensorVideo, CenterCropVideo
from ..utils.dataset_utils import DecordInit
def TemporalRandomCrop(total_frames, size):
"""
Performs a random temporal crop on a video sequence.
This function randomly selects a continuous frame sequence of length `size` from a video sequence.
`total_frames` indicates the total number of frames in the video sequence, and `size` represents the length of the frame sequence to be cropped.
Parameters:
- total_frames (int): The total number of frames in the video sequence.
- size (int): The length of the frame sequence to be cropped.
Returns:
- (int, int): A tuple containing two integers. The first integer is the starting frame index of the cropped sequence,
and the second integer is the ending frame index (exclusive) of the cropped sequence, clamped to `total_frames`.
"""
rand_end = max(0, total_frames - size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + size, total_frames)
return begin_index, end_index
def resize(x, resolution):
height, width = x.shape[-2:]
resolution = min(2 * resolution, height, width)
aspect_ratio = width / height
if width <= height:
new_width = resolution
new_height = int(resolution / aspect_ratio)
else:
new_height = resolution
new_width = int(resolution * aspect_ratio)
resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
return resized_x
class VideoDataset(data.Dataset):
""" Generic dataset for video files stored in folders
Each item is a dict whose `video` entry is a C T H W tensor in the range [-1, 1] """
video_exts = ['avi', 'mp4', 'webm']
def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True):
self.train = train
self.sequence_length = sequence_length
self.sample_rate = sample_rate
self.resolution = resolution
self.v_decoder = DecordInit()
self.video_folder = video_folder
self.dynamic_sample = dynamic_sample
self.transform = transforms.Compose([
ToTensorVideo(),
# Lambda(lambda x: resize(x, self.resolution)),
CenterCropVideo(self.resolution),
Lambda(lambda x: 2.0 * x - 1.0)
])
print('Building datasets...')
self.samples = self._make_dataset()
def _make_dataset(self):
samples = []
samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True)
for ext in self.video_exts], [])
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
video_path = self.samples[idx]
try:
video = self.decord_read(video_path)
video = self.transform(video) # T C H W -> T C H W
video = video.transpose(0, 1) # T C H W -> C T H W
return dict(video=video, label="")
except Exception as e:
print(f'Error with {e}, {video_path}')
return self.__getitem__(random.randint(0, self.__len__()-1))
def decord_read(self, path):
decord_vr = self.v_decoder(path)
total_frames = len(decord_vr)
# Sampling video frames
if self.dynamic_sample:
sample_rate = random.randint(1, self.sample_rate)
else:
sample_rate = self.sample_rate
size = self.sequence_length * sample_rate
start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size)
# assert end_frame_ind - start_frame_ind >= self.num_frames
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int)
video_data = decord_vr.get_batch(frame_indice).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
return video_data
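# Illustrative sketch (not part of the original module): how decord_read picks
# frame indices. It first takes a random window of sequence_length * sample_rate
# frames, then spreads sequence_length indices evenly across that window.
# `evenly_spaced` is a hypothetical stdlib stand-in for
# np.linspace(..., dtype=int), and `sample_frame_indices` is a hypothetical helper.

```python
import random


def temporal_random_crop(total_frames, size, rng=random):
    # Mirrors TemporalRandomCrop above: returns (begin, end) with end exclusive.
    rand_end = max(0, total_frames - size - 1)
    begin = rng.randint(0, rand_end)
    end = min(begin + size, total_frames)
    return begin, end


def evenly_spaced(start, stop, num):
    # `num` values from start to stop (inclusive), truncated to int.
    if num == 1:
        return [int(start)]
    step = (stop - start) / (num - 1)
    return [int(start + i * step) for i in range(num - 1)] + [int(stop)]


def sample_frame_indices(total_frames, sequence_length, sample_rate, rng=random):
    size = sequence_length * sample_rate
    begin, end = temporal_random_crop(total_frames, size, rng)
    return evenly_spaced(begin, end - 1, sequence_length)


# A 100-frame video, 8 output frames, every 2nd frame on average.
indices = sample_frame_indices(100, 8, 2, rng=random.Random(0))

# A video shorter than the window: the window is clamped, so frames repeat.
short_indices = sample_frame_indices(5, 8, 2, rng=random.Random(0))
```

# In the short-video case the indices stay in range but contain duplicates,
# so the returned clip repeats frames rather than failing.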
================================================
FILE: opensora/models/causalvideovae/model/ema_model.py
================================================
class EMA:
def __init__(self, model, decay):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
def register(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
if name in self.shadow:
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply_shadow(self):
for name, param in self.model.named_parameters():
if name in self.shadow:
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
for name, param in self.model.named_parameters():
if name in self.shadow:
param.data = self.backup[name]
self.backup = {}
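# Illustrative sketch (not part of the original module): the update rule used by
# EMA.update, shown on plain floats instead of tensors. `ema_update` is a
# hypothetical helper, and the "+= 1.0" line is only a stand-in for an optimizer step.

```python
def ema_update(shadow, params, decay):
    # Same rule as EMA.update above: new = (1 - decay) * param + decay * shadow
    return {
        name: (1.0 - decay) * params[name] + decay * shadow[name]
        for name in shadow
    }


params = {"w": 0.0}
shadow = dict(params)              # register(): start from a copy of the weights

for _ in range(3):
    params["w"] += 1.0             # stand-in for one optimizer step
    shadow = ema_update(shadow, params, decay=0.5)

# apply_shadow() / restore(): evaluate with the averaged weights, then
# continue training from the raw ones.
eval_weights = dict(shadow)
```

# With decay=0.5 the average lags the raw weight; decay values close to 1 are the
# usual choice in practice because they smooth over many more steps.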
================================================
FILE: opensora/models/causalvideovae/model/losses/__init__.py
================================================
from .perceptual_loss import LPIPSWithDiscriminator3D
================================================
FILE: opensora/models/causalvideovae/model/losses/discriminator.py
================================================
import functools
import torch.nn as nn
from ..modules.conv import CausalConv3d
from einops import rearrange
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
def weights_init_conv(m):
if hasattr(m, 'conv'):
m = m.conv
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator3D(nn.Module):
"""Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
"""
Construct a 3D PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input volumes
ndf (int) -- the number of filters in the first conv layer
n_layers (int) -- the number of conv layers in the discriminator
use_actnorm (bool) -- flag to use actnorm instead of batchnorm
"""
super(NLayerDiscriminator3D, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm3d
else:
raise NotImplementedError("Not implemented.")
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func != nn.BatchNorm3d
else:
use_bias = norm_layer != nn.BatchNorm3d
kw = 3
padw = 1
sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
# class NLayerDiscriminator3D(nn.Module):
# """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
# def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
# """
# Construct a 3D PatchGAN discriminator
# Parameters:
# input_nc (int) -- the number of channels in input volumes
# ndf (int) -- the number of filters in the last conv layer
# n_layers (int) -- the number of conv layers in the discriminator
# use_actnorm (bool) -- flag to use actnorm instead of batchnorm
# """
# super(NLayerDiscriminator3D, self).__init__()
# if not use_actnorm:
# norm_layer = nn.BatchNorm3d
# else:
# raise NotImplementedError("Not implemented.")
# if type(norm_layer) == functools.partial:
# use_bias = norm_layer.func != nn.BatchNorm3d
# else:
# use_bias = norm_layer != nn.BatchNorm3d
# kw = 4
# padw = 1
# sequence = [CausalConv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
# nf_mult = 1
# nf_mult_prev = 1
# for n in range(1, n_layers): # gradually increase the number of filters
# nf_mult_prev = nf_mult
# nf_mult = min(2 ** n, 8)
# sequence += [
# CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias),
# norm_layer(ndf * nf_mult),
# nn.LeakyReLU(0.2, True)
# ]
# nf_mult_prev = nf_mult
# nf_mult = min(2 ** n_layers, 8)
# sequence += [
# CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
# norm_layer(ndf * nf_mult),
# nn.LeakyReLU(0.2, True)
# ]
# sequence += [CausalConv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
# self.main = nn.Sequential(*sequence)
# def forward(self, input):
# """Standard forward."""
# return self.main(input)
================================================
FILE: opensora/models/causalvideovae/model/losses/lpips.py
================================================
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
from .....utils.taming_download import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vgg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, ".cache/lpips")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name, ".cache/lpips")  # same cache root as load_from_pretrained
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
""" A single linear layer which does a 1x1 conv """
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(), ] if (use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
return x/(norm_factor+eps)
def spatial_average(x, keepdim=True):
return x.mean([2,3],keepdim=keepdim)
================================================
FILE: opensora/models/causalvideovae/model/losses/perceptual_loss.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from .lpips import LPIPS
from einops import rearrange
from .discriminator import weights_init, NLayerDiscriminator3D
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real))
+ torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
def l1(x, y):
return torch.abs(x - y)
def l2(x, y):
return torch.pow((x - y), 2)
class LPIPSWithDiscriminator3D(nn.Module):
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
perceptual_weight=1.0,
disc_num_layers=4,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss="hinge",
learn_logvar: bool = False,
wavelet_weight=0.01,
loss_type: str = "l1",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.wavelet_weight = wavelet_weight
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.logvar = nn.Parameter(
torch.full((), logvar_init), requires_grad=learn_logvar
)
self.discriminator = NLayerDiscriminator3D(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.loss_func = l1 if loss_type == "l1" else l2
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
layer = last_layer if last_layer is not None else self.last_layer[0]
nll_grads = torch.autograd.grad(nll_loss, layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
split="train",
weights=None,
last_layer=None,
wavelet_coeffs=None,
cond=None,
):
bs = inputs.shape[0]
t = inputs.shape[2]
if optimizer_idx == 0:
inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
reconstructions = rearrange(
reconstructions, "b c t h w -> (b t) c h w"
).contiguous()
rec_loss = self.loss_func(inputs, reconstructions)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs, reconstructions)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = (
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
)
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
if wavelet_coeffs:
wl_loss_l2 = torch.sum(l1(wavelet_coeffs[0], wavelet_coeffs[1])) / bs
wl_loss_l3 = torch.sum(l1(wavelet_coeffs[2], wavelet_coeffs[3])) / bs
wl_loss = wl_loss_l2 + wl_loss_l3
else:
wl_loss = torch.tensor(0.0)
inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t).contiguous()
reconstructions = rearrange(
reconstructions, "(b t) c h w -> b c t h w", t=t
).contiguous()
logits_fake = self.discriminator(reconstructions)
g_loss = -torch.mean(logits_fake)
if global_step >= self.discriminator_iter_start:
if self.disc_factor > 0.0:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
else:
d_weight = torch.tensor(1.0)
else:
d_weight = torch.tensor(0.0)
g_loss = torch.tensor(0.0, requires_grad=True)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
+ self.wavelet_weight * wl_loss
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): weighted_nll_loss.detach().mean(),
"{}/wl_loss".format(split): wl_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
return loss, log
elif optimizer_idx == 1:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
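# Illustrative sketch (not part of the original module): the hinge discriminator
# loss and the adopt_weight warm-up gate, evaluated on plain Python lists instead
# of tensors. The helpers `relu` and `mean` are hypothetical stand-ins for
# F.relu and torch.mean.

```python
def relu(x):
    return max(0.0, x)


def mean(values):
    return sum(values) / len(values)


def hinge_d_loss(logits_real, logits_fake):
    # Real samples are pushed above +1, fake samples below -1.
    loss_real = mean([relu(1.0 - r) for r in logits_real])
    loss_fake = mean([relu(1.0 + f) for f in logits_fake])
    return 0.5 * (loss_real + loss_fake)


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # Before `threshold` steps the adversarial term is replaced by `value`.
    return value if global_step < threshold else weight


# A discriminator that separates both sets by the margin pays no loss.
confident = hinge_d_loss([2.0, 1.5], [-2.0, -1.0])

# A discriminator that outputs 0 everywhere pays the full margin on both sides.
unsure = hinge_d_loss([0.0, 0.0], [0.0, 0.0])

# The GAN term is switched off until step 100, then uses the configured factor.
before = adopt_weight(1.0, global_step=10, threshold=100)
after = adopt_weight(1.0, global_step=100, threshold=100)
```

# This gate is why both optimizer branches above multiply by disc_factor: the
# autoencoder trains on reconstruction and KL terms alone until disc_start.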
================================================
FILE: opensora/models/causalvideovae/model/modeling_videobase.py
================================================
import torch
import glob
import json
import os
from typing import Optional, Union
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from torch import nn
class VideoBaseAE(ModelMixin, ConfigMixin):
config_name = "config.json"
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def encode(self, x: torch.Tensor, *args, **kwargs):
pass
def decode(self, encoding: torch.Tensor, *args, **kwargs):
pass
@property
def num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if self.trainer.max_steps:
return self.trainer.max_steps
limit_batches = self.trainer.limit_train_batches
batches = len(self.train_dataloader())
batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)
effective_accum = self.trainer.accumulate_grad_batches * num_devices
return (batches // effective_accum) * self.trainer.max_epochs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt'))
if ckpt_files:
# Adapt to checkpoint. glob returns files in arbitrary order, so sort by name first.
last_ckpt_file = sorted(ckpt_files)[-1]
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
model = cls.from_config(config_file)
model.init_from_ckpt(last_ckpt_file)
return model
else:
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
================================================
FILE: opensora/models/causalvideovae/model/modules/__init__.py
================================================
from .block import Block
from .attention import *
from .conv import *
from .normalize import *
from .resnet_block import *
from .updownsample import *
from .wavelet import *
================================================
FILE: opensora/models/causalvideovae/model/modules/attention.py
================================================
import torch.nn as nn
import torch.nn.functional as F
from .normalize import Normalize
from .conv import CausalConv3d
import torch
from .block import Block
try:
import torch_npu
from opensora.npu_config import npu_config, set_run_dtype
except Exception:
torch_npu = None
npu_config = None
# from xformers import ops as xops
class AttnBlock3D(Block):
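"""Kept for compatibility with old checkpoints; use with caution. Known issue: q, k and v are reshaped from (b, c, t, h, w) to (b * t, c, h * w) without a permute, which mixes the channel and time axes. Prefer AttnBlock3DFix."""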
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, t, h, w = q.shape
q = q.reshape(b * t, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b * t, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b * t, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, t, h, w)
h_ = self.proj_out(h_)
return x + h_
class AttnBlock3DFix(nn.Module):
"""
Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
"""
def __init__(self, in_channels, norm_type="groupnorm"):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, norm_type=norm_type)
self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, t, h, w = q.shape
q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
if torch_npu is None:
# attn_output = xops.memory_efficient_attention(
# q, k, v,
# scale=c ** -0.5
# )
q = q.view(b * t, -1, 1, c).transpose(1, 2)
k = k.view(b * t, -1, 1, c).transpose(1, 2)
v = v.view(b * t, -1, 1, c).transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
attn_output = attn_output.transpose(1, 2).reshape(b * t, -1, 1 * c)
else:
# print('npu_config.enable_FA, q.dtype == torch.float32', npu_config.enable_FA, q.dtype == torch.float32)
if npu_config.enable_FA and q.dtype == torch.float32:
dtype = torch.bfloat16
else:
dtype = None
with set_run_dtype(q, dtype):
query, key, value = npu_config.set_current_run_dtype([q, k, v])
hidden_states = npu_config.run_attention(query, key, value, atten_mask=None, input_layout="BSH",
head_dim=c, head_num=1)
attn_output = npu_config.restore_dtype(hidden_states)
attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
h_ = self.proj_out(attn_output)
return x + h_
================================================
FILE: opensora/models/causalvideovae/model/modules/block.py
================================================
import torch.nn as nn
class Block(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
================================================
FILE: opensora/models/causalvideovae/model/modules/conv.py
================================================
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:
torch_npu = None
npu_config = None
import torch.nn as nn
from typing import Union, Tuple
import torch
from .block import Block
from .ops import cast_tuple
from .ops import video_to_image
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from collections import deque
class Conv2d(nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]] = 3,
stride: Union[int, Tuple[int]] = 1,
padding: Union[str, int, Tuple[int]] = 0,
dilation: Union[int, Tuple[int]] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
@video_to_image
def forward(self, x):
return super().forward(x)
class CausalConv3d(Block):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
enable_cached=False,
bias=True,
**kwargs
):
super().__init__()
self.kernel_size = cast_tuple(kernel_size, 3)
self.time_kernel_size = self.kernel_size[0]
self.chan_in = chan_in
self.chan_out = chan_out
self.stride = kwargs.pop("stride", 1)
self.padding = kwargs.pop("padding", 0)
self.padding = list(cast_tuple(self.padding, 3))
self.padding[0] = 0
self.stride = cast_tuple(self.stride, 3)
self.conv = nn.Conv3d(
chan_in,
chan_out,
self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=bias
)
self.enable_cached = enable_cached
self.is_first_chunk = True
self.causal_cached = deque()
self.cache_offset = 0
def forward(self, x):
if self.is_first_chunk:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
else:
first_frame_pad = self.causal_cached.popleft()
x = torch.concatenate((first_frame_pad, x), dim=2)
if self.enable_cached and self.time_kernel_size != 1:
if (self.time_kernel_size - 1) // self.stride[0] != 0:
if self.cache_offset == 0:
self.causal_cached.append(x[:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone())
else:
self.causal_cached.append(x[:, :, :-self.cache_offset][:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone())
else:
self.causal_cached.append(x[:, :, 0:0, :, :].clone())
elif self.enable_cached:
self.causal_cached.append(x[:, :, 0:0, :, :].clone())
x = self.conv(x)
return x
class CausalConv3d_GC(CausalConv3d):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int]],
init_method="random",
**kwargs
):
# init_method is accepted but not forwarded: the parent's fourth positional parameter is enable_cached.
super().__init__(chan_in, chan_out, kernel_size, **kwargs)
def forward(self, x):
# Input is 1 + 16 frames: the first frame is treated as an image, the other 16 as video.
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
) # b c t h w
x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
return checkpoint(self.conv, x)
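The causal padding used by `CausalConv3d.forward` can be sketched on a single time axis. This is a minimal plain-Python illustration under stated assumptions, not the repository's implementation: `causal_pad` and `conv1d_valid` are hypothetical helper names, and each frame is reduced to one scalar so the arithmetic stays visible.

```python
def causal_pad(frames, time_kernel_size):
    """Repeat the first frame (time_kernel_size - 1) times on the past side only."""
    return [frames[0]] * (time_kernel_size - 1) + list(frames)


def conv1d_valid(frames, kernel, stride=1):
    """'Valid' 1D correlation: the time-axis analogue of a conv with zero time padding."""
    k = len(kernel)
    return [
        sum(w * frames[i + j] for j, w in enumerate(kernel))
        for i in range(0, len(frames) - k + 1, stride)
    ]


frames = [1.0, 2.0, 3.0, 4.0, 5.0]
kernel = [1.0, 1.0, 1.0]

out = conv1d_valid(causal_pad(frames, len(kernel)), kernel)

# Changing only the last frame must leave every earlier output untouched.
changed = frames[:-1] + [100.0]
out_changed = conv1d_valid(causal_pad(changed, len(kernel)), kernel)
```

With stride 1 the output keeps the input length, and the output at index `t` depends only on frames up to `t`.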
================================================
FILE: opensora/models/causalvideovae/model/modules/normalize.py
================================================
import torch
import torch.nn as nn
from .block import Block
from einops import rearrange
class GroupNorm(Block):
def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.norm = torch.nn.GroupNorm(
num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True
)
def forward(self, x):
return self.norm(x)
class LayerNorm(Block):
def __init__(self, num_channels, eps=1e-6, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.norm = torch.nn.LayerNorm(num_channels, eps=eps, elementwise_affine=True)
def forward(self, x):
if x.dim() == 5:
x = rearrange(x, "b c t h w -> b t h w c")
x = self.norm(x)
x = rearrange(x, "b t h w c -> b c t h w")
else:
x = rearrange(x, "b c h w -> b h w c")
x = self.norm(x)
x = rearrange(x, "b h w c -> b c h w")
return x
def Normalize(in_channels, num_groups=32, norm_type="groupnorm"):
if norm_type == "groupnorm":
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
elif norm_type == "layernorm":
return LayerNorm(num_channels=in_channels, eps=1e-6)
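`LayerNorm.forward` moves the channel axis last because layer normalization acts on the trailing axis. The sketch below shows the same idea in plain Python on a tiny channels-first array. It omits the learned affine parameters, and `layer_norm` and `channel_layer_norm` are hypothetical names used only for this illustration.

```python
import math


def layer_norm(values, eps=1e-6):
    """Normalize one vector to zero mean and unit (biased) variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def channel_layer_norm(x):
    """x is channels-first: x[c][p] for channel c and flattened position p."""
    positions = list(zip(*x))                         # channels-last: positions[p][c]
    normed = [layer_norm(list(p)) for p in positions]
    return [list(c) for c in zip(*normed)]            # back to channels-first


x = [[1.0, 2.0, 3.0],
     [3.0, 2.0, 1.0]]
y = channel_layer_norm(x)
```

Each position is normalized across its channels independently, so the values at one position sum to approximately zero after the call.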
================================================
FILE: opensora/models/causalvideovae/model/modules/ops.py
================================================
import torch
from einops import rearrange
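# Decorator that lets a 2D (image) forward run on 5D video input by folding the time axis into the batch axis.
def video_to_image(func):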
def wrapper(self, x, *args, **kwargs):
if x.dim() == 5:
t = x.shape[2]
if True:
x = rearrange(x, "b c t h w -> (b t) c h w")
x = func(self, x, *args, **kwargs)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
else:
# Conv 2d slice infer
result = []
for i in range(t):
frame = x[:, :, i, :, :]
frame = func(self, frame, *args, **kwargs)
result.append(frame.unsqueeze(2))
x = torch.concatenate(result, dim=2)
return x
return wrapper
def nonlinearity(x):
return x * torch.sigmoid(x)
def cast_tuple(t, length=1):
return t if isinstance(t, (tuple, list)) else ((t,) * length)
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
n_dims = len(x.shape)
if src_dim < 0:
src_dim = n_dims + src_dim
if dest_dim < 0:
dest_dim = n_dims + dest_dim
assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
dims = list(range(n_dims))
del dims[src_dim]
permutation = []
ctr = 0
for i in range(n_dims):
if i == dest_dim:
permutation.append(src_dim)
else:
permutation.append(dims[ctr])
ctr += 1
x = x.permute(permutation)
if make_contiguous:
x = x.contiguous()
return x
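The axis reordering that `shift_dim` computes can be checked without tensors. The following is a small sketch, assuming only the logic shown above: `shift_permutation` and `permute_shape` are hypothetical names, and shapes stand in for real tensors.

```python
def shift_permutation(n_dims, src_dim=-1, dest_dim=-1):
    """Return the axis order that moves src_dim to dest_dim, keeping the rest in order."""
    if src_dim < 0:
        src_dim += n_dims
    if dest_dim < 0:
        dest_dim += n_dims
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
    dims = [d for d in range(n_dims) if d != src_dim]
    dims.insert(dest_dim, src_dim)
    return dims


def permute_shape(shape, permutation):
    """Shape a tensor would have after permuting its axes."""
    return tuple(shape[p] for p in permutation)


shape = (2, 8, 5, 16, 16)                  # b c t h w
to_last = shift_permutation(5, 1, -1)      # move channels to the end
to_second = shift_permutation(5, -1, 1)    # move the last axis back to position 1

channels_last = permute_shape(shape, to_last)
restored = permute_shape(channels_last, to_second)
```

`Codebook` uses the first call to flatten a `[b, c, t, h, w]` latent into rows of length `c`, and the second to return the quantized result to channels-first layout.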
================================================
FILE: opensora/models/causalvideovae/model/modules/quant.py
================================================
import torch
import torch.nn as nn
import torch.distributed as dist
import numpy as np
import torch.nn.functional as F
from .ops import shift_dim
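# Vector-quantization codebook whose entries are updated with exponential moving averages rather than gradients.
class Codebook(nn.Module):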
def __init__(self, n_codes, embedding_dim):
super().__init__()
self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
self.register_buffer("N", torch.zeros(n_codes))
self.register_buffer("z_avg", self.embeddings.data.clone())
self.n_codes = n_codes
self.embedding_dim = embedding_dim
self._need_init = True
def _tile(self, x):
d, ew = x.shape
if d < self.n_codes:
n_repeats = (self.n_codes + d - 1) // d
std = 0.01 / np.sqrt(ew)
x = x.repeat(n_repeats, 1)
x = x + torch.randn_like(x) * std
return x
def _init_embeddings(self, z):
# z: [b, c, t, h, w]
self._need_init = False
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
y = self._tile(flat_inputs)
d = y.shape[0]
_k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
if dist.is_initialized():
dist.broadcast(_k_rand, 0)
self.embeddings.data.copy_(_k_rand)
self.z_avg.data.copy_(_k_rand)
self.N.data.copy_(torch.ones(self.n_codes))
def forward(self, z):
# z: [b, c, t, h, w]
if self._need_init and self.training:
self._init_embeddings(z)
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
distances = (
(flat_inputs**2).sum(dim=1, keepdim=True)
- 2 * flat_inputs @ self.embeddings.t()
+ (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
)
encoding_indices = torch.argmin(distances, dim=1)
encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
embeddings = F.embedding(encoding_indices, self.embeddings)
embeddings = shift_dim(embeddings, -1, 1)
commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
# EMA codebook update
if self.training:
n_total = encode_onehot.sum(dim=0)
encode_sum = flat_inputs.t() @ encode_onehot
if dist.is_initialized():
dist.all_reduce(n_total)
dist.all_reduce(encode_sum)
self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
n = self.N.sum()
weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
encode_normalized = self.z_avg / weights.unsqueeze(1)
self.embeddings.data.copy_(encode_normalized)
y = self._tile(flat_inputs)
_k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
if dist.is_initialized():
dist.broadcast(_k_rand, 0)
usage = (self.N.view(self.n_codes, 1) >= 1).float()
self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
embeddings_st = (embeddings - z).detach() + z
avg_probs = torch.mean(encode_onehot, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
return dict(
embeddings=embeddings_st,
encodings=encoding_indices,
commitment_loss=commitment_loss,
perplexity=perplexity,
)
def dictionary_lookup(self, encodings):
embeddings = F.embedding(encodings, self.embeddings)
return embeddings
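Two pieces of `Codebook.forward` are easy to verify by hand: the nearest-code search through the expanded squared distance, and the perplexity of code usage. This is a hedged plain-Python sketch, with `nearest_code` and `perplexity` as hypothetical helper names and a made-up four-entry codebook.

```python
import math


def nearest_code(z, codebook):
    """Index of the closest code, using ||z||^2 - 2 z.e + ||e||^2."""
    z_sq = sum(v * v for v in z)
    distances = [
        z_sq - 2 * sum(a * b for a, b in zip(z, e)) + sum(v * v for v in e)
        for e in codebook
    ]
    return min(range(len(codebook)), key=distances.__getitem__)


def perplexity(indices, n_codes):
    """exp(entropy) of code usage; close to n_codes when all codes are used equally."""
    probs = [indices.count(k) / len(indices) for k in range(n_codes)]
    return math.exp(-sum(p * math.log(p + 1e-10) for p in probs))


codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
idx = nearest_code([0.9, 0.2], codebook)

uniform_usage = perplexity([0, 1, 2, 3], n_codes=4)
collapsed_usage = perplexity([1, 1, 1, 1], n_codes=4)
```

A perplexity near the codebook size suggests healthy usage, while a value near 1 indicates that almost every input maps to the same code.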
================================================
FILE: opensora/models/causalvideovae/model/modules/resnet_block.py
================================================
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:
torch_npu = None
npu_config = None
import torch
from .normalize import Normalize
from .ops import nonlinearity, video_to_image
from .conv import CausalConv3d
from .block import Block
from torch.utils.checkpoint import checkpoint
class ResnetBlock2D(Block):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
norm_type,
dropout,
):
super().__init__()
self.in_channels = in_channels
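# Resolve the default once so the layers below never receive out_channels=None.
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels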
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, norm_type=norm_type)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = Normalize(out_channels, norm_type=norm_type)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
@video_to_image
def forward(self, x):
h = x
if npu_config is None:
h = self.norm1(h)
else:
h = npu_config.run_group_norm(self.norm1, h)
h = nonlinearity(h)
h = self.conv1(h)
if npu_config is None:
h = self.norm2(h)
else:
h = npu_config.run_group_norm(self.norm2, h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
x = x + h
return x
class ResnetBlock3D(Block):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
norm_type,
):
super().__init__()
self.in_channels = in_channels
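# Resolve the default once so the layers below never receive out_channels=None.
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels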
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, norm_type=norm_type)
self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
self.norm2 = Normalize(out_channels, norm_type=norm_type)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConv3d(
in_channels, out_channels, 3, padding=1
)
else:
self.nin_shortcut = CausalConv3d(
in_channels, out_channels, 1, padding=0
)
def forward(self, x):
h = x
if npu_config is None:
h = self.norm1(h)
else:
h = npu_config.run_group_norm(self.norm1, h)
h = nonlinearity(h)
h = self.conv1(h)
if npu_config is None:
h = self.norm2(h)
else:
h = npu_config.run_group_norm(self.norm2, h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class ResnetBlock3D_GC(Block):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
norm_type,
dropout,
):
super().__init__()
self.in_channels = in_channels
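# Resolve the default once so the layers below never receive out_channels=None.
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels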
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, norm_type=norm_type)
self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
self.norm2 = Normalize(out_channels, norm_type=norm_type)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConv3d(
in_channels, out_channels, 3, padding=1
)
else:
self.nin_shortcut = CausalConv3d(
in_channels, out_channels, 1, padding=0
)
def forward(self, x):
return checkpoint(self._forward, x, use_reentrant=True)
def _forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
================================================
FILE: opensora/models/causalvideovae/model/modules/updownsample.py
================================================
from typing import Union, Tuple
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ops import cast_tuple, video_to_image
from .conv import CausalConv3d, CausalConv3d_GC
from einops import rearrange
from .block import Block
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:
torch_npu = None
npu_config = None
class Upsample(Block):
def __init__(self, in_channels, out_channels):
super().__init__()
self.with_conv = True
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
@video_to_image
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(Block):
def __init__(self, in_channels, out_channels, undown=False):
super().__init__()
self.with_conv = True
self.undown = undown
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
if self.undown:
self.conv = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.conv = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=0)
@video_to_image
def forward(self, x):
if self.with_conv:
if self.undown:
if npu_config is not None and npu_config.on_npu:
x_dtype = x.dtype
x = x.to(npu_config.replaced_type)
x = npu_config.run_conv3d(self.conv, x, x_dtype)
else:
x = self.conv(x)
else:
pad = (0, 1, 0, 1)
if npu_config is not None and npu_config.on_npu:
x_dtype = x.dtype
x = x.to(npu_config.replaced_type)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = npu_config.run_conv3d(self.conv, x, x_dtype)
else:
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class SpatialDownsample2x(Block):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int]] = (3, 3),
stride: Union[int, Tuple[int]] = (2, 2),
**kwargs
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 2)
stride = cast_tuple(stride, 2)
self.chan_in = chan_in
self.chan_out = chan_out
self.kernel_size = kernel_size
self.conv = CausalConv3d(
self.chan_in,
self.chan_out,
(1,) + self.kernel_size,
stride=(1, ) + stride,
padding=0
)
def forward(self, x):
pad = (0,1,0,1,0,0)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class SpatialUpsample2x_GC(Block):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int]] = (3, 3),
stride: Union[int, Tuple[int]] = (1, 1),
unup=False,
):
super().__init__()
self.chan_in = chan_in
self.chan_out = chan_out
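# Accept an int as well as a pair, matching the type hints and SpatialDownsample2x.
self.kernel_size = cast_tuple(kernel_size, 2)
self.unup = unup
self.conv = CausalConv3d_GC(
self.chan_in,
self.chan_out,
(1,) + tuple(self.kernel_size),
stride=(1, ) + tuple(cast_tuple(stride, 2)),
padding=1
)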
def forward(self, x):
if not self.unup:
t = x.shape[2]
x = rearrange(x, "b c t h w -> b (c t) h w")
x = F.interpolate(x, scale_factor=(2,2), mode="nearest")
x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
x = self.conv(x)
return x
class SpatialUpsample2x(Block):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int]] = (3, 3),
stride: Union[int, Tuple[int]] = (1, 1),
unup=False,
):
super().__init__()
self.chan_in = chan_in
self.chan_out = chan_out
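# Accept an int as well as a pair, matching the type hints and SpatialDownsample2x.
self.kernel_size = cast_tuple(kernel_size, 2)
self.unup = unup
self.conv = CausalConv3d(
self.chan_in,
self.chan_out,
(1,) + tuple(self.kernel_size),
stride=(1, ) + tuple(cast_tuple(stride, 2)),
padding=1
)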
def forward(self, x):
if not self.unup:
t = x.shape[2]
x = rearrange(x, "b c t h w -> b (c t) h w")
x = F.interpolate(x, scale_factor=(2,2), mode="nearest")
x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
x = self.conv(x)
return x
class TimeDownsample2x(Block):
def __init__(
self,
chan_in,
chan_out,
kernel_size: int = 3
):
super().__init__()
self.kernel_size = kernel_size
if npu_config is not None and npu_config.on_npu:
self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1))
self.pad = nn.ReplicationPad3d((0, 0, 0, 0, self.kernel_size - 1, 0))
else:
self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
def forward(self, x):
if npu_config is not None and npu_config.on_npu:
n, c, d, h, w = x.shape
x = self.pad(x)
x = x.view(n * c, -1, h * w)
pooled = self.avg_pool(x)
output = pooled.view(n, c, -1, h, w)
return output
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
return self.conv(x)
class TimeUpsample2x(Block):
def __init__(
self,
chan_in,
chan_out
):
super().__init__()
def forward(self, x):
if x.size(2) > 1:
x,x_= x[:,:,:1],x[:,:,1:]
x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear')
x = torch.concat([x, x_], dim=2)
return x
class TimeDownsampleRes2x(Block):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
mix_factor: float = 2.0,
):
super().__init__()
self.kernel_size = cast_tuple(kernel_size, 3)
if npu_config is not None and npu_config.on_npu:
self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1))
self.pad = nn.ReplicationPad3d((0, 0, 0, 0, kernel_size - 1, 0))
else:
self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
self.conv = nn.Conv3d(
in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1)
)
self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
def forward(self, x):
alpha = torch.sigmoid(self.mix_factor)
if npu_config is not None and npu_config.on_npu:
n, c, d, h, w = x.shape
x_dtype = x.dtype
x = x.to(npu_config.replaced_type)
x = self.pad(x)
pad_x = x.view(n, c, -1, h, w)
avg_x = self.avg_pool(x.view(n * c, -1, h * w)).view(n, c, -1, h, w).to(x_dtype)
conv_x = npu_config.run_conv3d(self.conv, pad_x, x_dtype)
return alpha * avg_x + (1 - alpha) * conv_x
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.kernel_size[0] - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
class TimeUpsampleRes2x(Block):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
mix_factor: float = 2.0,
):
super().__init__()
self.conv = CausalConv3d(
in_channels, out_channels, kernel_size, padding=1
)
self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
def forward(self, x):
alpha = torch.sigmoid(self.mix_factor)
if x.size(2) > 1:
x,x_= x[:,:,:1],x[:,:,1:]
if npu_config is not None and npu_config.on_npu:
x_dtype = x_.dtype
x_ = x_.to(npu_config.replaced_type)
x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode='trilinear')
x_ = x_.to(x_dtype)
else:
x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear')
x = torch.concat([x, x_], dim=2)
return alpha * x + (1-alpha) * self.conv(x)
class Spatial2xTime2x3DDownsample(Block):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2)
def forward(self, x):
pad = (0,1,0,1,0,0)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Spatial2x3DDownsample(Block):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=(1,2,2))
def forward(self, x):
pad = (0,1,0,1,0,0)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Spatial2x3DUpsample(Block):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=(1,2,2), mode='trilinear')
return self.conv(x)
class Spatial2xTime2x3DUpsample(Block):
def __init__(
self,
in_channels,
out_channels,
t_interpolation="trilinear",
enable_cached=False,
):
super().__init__()
self.t_interpolation = t_interpolation
self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1)
self.enable_cached = enable_cached
self.causal_cached = deque()
def forward(self, x):
if x.size(2) > 1 or len(self.causal_cached) > 0:
if self.enable_cached and len(self.causal_cached) > 0:
x = torch.cat([self.causal_cached.popleft(), x], dim=2)
self.causal_cached.append(x[:, :, -2:-1].clone())
x = F.interpolate(x, scale_factor=(2, 1, 1), mode=self.t_interpolation)
x = x[:, :, 2:]
x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear")
else:
if self.enable_cached:
self.causal_cached.append(x[:, :, -1:].clone())
x, x_ = x[:, :, :1], x[:, :, 1:]
x_ = F.interpolate(
x_, scale_factor=(2, 1, 1), mode=self.t_interpolation
)
x_ = F.interpolate(x_, scale_factor=(1, 2, 2), mode="trilinear")
x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear")
x = torch.concat([x, x_], dim=2)
else:
if self.enable_cached:
self.causal_cached.append(x[:, :, -1:].clone())
x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear")
return self.conv(x)
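The frame and pixel counts produced by the down/upsampling blocks above follow from simple arithmetic. The sketch below reproduces it in plain Python for the non-NPU code paths; the three function names are hypothetical and the formulas are derived from the padding, kernel, and stride values shown in `TimeDownsample2x`, `TimeUpsample2x`, and `Downsample`.

```python
def time_downsample_len(t, kernel_size=3, stride=2):
    """Frames left after first-frame padding (kernel_size - 1) and a strided pooling window."""
    padded = t + (kernel_size - 1)
    return (padded - kernel_size) // stride + 1


def time_upsample_len(t):
    """The first frame is kept as is; the remaining t - 1 frames are doubled."""
    return 1 + 2 * (t - 1) if t > 1 else t


def spatial_downsample_len(size, kernel_size=3, stride=2):
    """One pixel of zero padding on the right/bottom, then an unpadded strided conv."""
    return (size + 1 - kernel_size) // stride + 1


frames = 17
compressed = time_downsample_len(frames)
restored = time_upsample_len(compressed)
half_height = spatial_downsample_len(256)
```

Clips of `1 + 2n` frames therefore survive a downsample/upsample round trip with the same length, which is consistent with the `1 + 16` frame convention mentioned in `CausalConv3d_GC`.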
================================================
FILE: opensora/models/causalvideovae/model/modules/wavelet.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from ..modules import CausalConv3d
from ..modules.ops import video_to_image
from einops import rearrange
try:
import torch_npu
from opensora.npu_config import npu_config, set_run_dtype
except Exception as e:
torch_npu = None
npu_config = None
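# Single-level 3D Haar wavelet transform: each input channel yields 8 sub-band channels (low/high along time, height and width).
# The filter scale 0.3536 is 1 / (2 * sqrt(2)), rounded to four decimals.
class HaarWaveletTransform3D(nn.Module):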
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536
g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536
hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536
gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536
h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536
g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536
hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536
gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536
h = h.view(1, 1, 2, 2, 2)
g = g.view(1, 1, 2, 2, 2)
hh = hh.view(1, 1, 2, 2, 2)
gh = gh.view(1, 1, 2, 2, 2)
h_v = h_v.view(1, 1, 2, 2, 2)
g_v = g_v.view(1, 1, 2, 2, 2)
hh_v = hh_v.view(1, 1, 2, 2, 2)
gh_v = gh_v.view(1, 1, 2, 2, 2)
self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.h_conv.conv.weight.data = h
self.g_conv.conv.weight.data = g
self.hh_conv.conv.weight.data = hh
self.gh_conv.conv.weight.data = gh
self.h_v_conv.conv.weight.data = h_v
self.g_v_conv.conv.weight.data = g_v
self.hh_v_conv.conv.weight.data = hh_v
self.gh_v_conv.conv.weight.data = gh_v
self.h_conv.requires_grad_(False)
self.g_conv.requires_grad_(False)
self.hh_conv.requires_grad_(False)
self.gh_conv.requires_grad_(False)
self.h_v_conv.requires_grad_(False)
self.g_v_conv.requires_grad_(False)
self.hh_v_conv.requires_grad_(False)
self.gh_v_conv.requires_grad_(False)
def forward(self, x):
assert x.dim() == 5
if torch_npu is not None:
dtype = x.dtype
x = x.to(npu_config.conv_dtype)
self.to(npu_config.conv_dtype)
b = x.shape[0]
x = rearrange(x, "b c t h w -> (b c) 1 t h w")
low_low_low = self.h_conv(x)
low_low_low = rearrange(low_low_low, "(b c) 1 t h w -> b c t h w", b=b)
low_low_high = self.g_conv(x)
low_low_high = rearrange(low_low_high, "(b c) 1 t h w -> b c t h w", b=b)
low_high_low = self.hh_conv(x)
low_high_low = rearrange(low_high_low, "(b c) 1 t h w -> b c t h w", b=b)
low_high_high = self.gh_conv(x)
low_high_high = rearrange(low_high_high, "(b c) 1 t h w -> b c t h w", b=b)
high_low_low = self.h_v_conv(x)
high_low_low = rearrange(high_low_low, "(b c) 1 t h w -> b c t h w", b=b)
high_low_high = self.g_v_conv(x)
high_low_high = rearrange(high_low_high, "(b c) 1 t h w -> b c t h w", b=b)
high_high_low = self.hh_v_conv(x)
high_high_low = rearrange(high_high_low, "(b c) 1 t h w -> b c t h w", b=b)
high_high_high = self.gh_v_conv(x)
high_high_high = rearrange(high_high_high, "(b c) 1 t h w -> b c t h w", b=b)
output = torch.cat(
[
low_low_low,
low_low_high,
low_high_low,
low_high_high,
high_low_low,
high_low_high,
high_high_low,
high_high_high,
],
dim=1,
)
if torch_npu is not None:
x = x.to(dtype)
output = output.to(dtype)
self.to(dtype)
return output
class InverseHaarWaveletTransform3D(nn.Module):
def __init__(self, enable_cached=False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.register_buffer('h',
torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('g',
torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('hh',
torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('gh',
torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('h_v',
torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('g_v',
torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('hh_v',
torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('gh_v',
torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.enable_cached = enable_cached
self.is_first_chunk = True
def forward(self, coeffs):
assert coeffs.dim() == 5
if torch_npu is not None:
dtype = coeffs.dtype
coeffs = coeffs.to(npu_config.conv_dtype)
self.h = self.h.to(npu_config.conv_dtype)
self.g = self.g.to(npu_config.conv_dtype)
self.hh = self.hh.to(npu_config.conv_dtype)
self.gh = self.gh.to(npu_config.conv_dtype)
self.h_v = self.h_v.to(npu_config.conv_dtype)
self.g_v = self.g_v.to(npu_config.conv_dtype)
self.hh_v = self.hh_v.to(npu_config.conv_dtype)
self.gh_v = self.gh_v.to(npu_config.conv_dtype)
b = coeffs.shape[0]
(
low_low_low,
low_low_high,
low_high_low,
low_high_high,
high_low_low,
high_low_high,
high_high_low,
high_high_high,
) = coeffs.chunk(8, dim=1)
low_low_low = rearrange(low_low_low, "b c t h w -> (b c) 1 t h w")
low_low_high = rearrange(low_low_high, "b c t h w -> (b c) 1 t h w")
low_high_low = rearrange(low_high_low, "b c t h w -> (b c) 1 t h w")
low_high_high = rearrange(low_high_high, "b c t h w -> (b c) 1 t h w")
high_low_low = rearrange(high_low_low, "b c t h w -> (b c) 1 t h w")
high_low_high = rearrange(high_low_high, "b c t h w -> (b c) 1 t h w")
high_high_low = rearrange(high_high_low, "b c t h w -> (b c) 1 t h w")
high_high_high = rearrange(high_high_high, "b c t h w -> (b c) 1 t h w")
low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2)
low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2)
low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2)
low_high_high = F.conv_transpose3d(low_high_high, self.gh, stride=2)
high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2)
high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2)
high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2)
high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2)
if self.enable_cached and not self.is_first_chunk:
reconstructed = (
low_low_low
+ low_low_high
+ low_high_low
+ low_high_high
+ high_low_low
+ high_low_high
+ high_high_low
+ high_high_high
)
else:
reconstructed = (
low_low_low[:, :, 1:]
+ low_low_high[:, :, 1:]
+ low_high_low[:, :, 1:]
+ low_high_high[:, :, 1:]
+ high_low_low[:, :, 1:]
+ high_low_high[:, :, 1:]
+ high_high_low[:, :, 1:]
+ high_high_high[:, :, 1:]
)
reconstructed = rearrange(reconstructed, "(b c) 1 t h w -> b c t h w", b=b)
if torch_npu is not None:
coeffs = coeffs.to(dtype)
reconstructed = reconstructed.to(dtype)
self.h = self.h.to(dtype)
self.g = self.g.to(dtype)
self.hh = self.hh.to(dtype)
self.gh = self.gh.to(dtype)
self.h_v = self.h_v.to(dtype)
self.g_v = self.g_v.to(dtype)
self.hh_v = self.hh_v.to(dtype)
self.gh_v = self.gh_v.to(dtype)
return reconstructed
class HaarWaveletTransform2D(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)
@video_to_image
def forward(self, x):
b, c, h, w = x.shape
x = x.reshape(b * c, 1, h, w)
low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2)
low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2)
high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2)
high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2)
coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1)
return coeffs
class InverseHaarWaveletTransform2D(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)
@video_to_image
def forward(self, coeffs):
low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1)
b, c, height_half, width_half = low_low.shape
height = height_half * 2
width = width_half * 2
low_low = F.conv_transpose2d(
low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2
)
low_high = F.conv_transpose2d(
low_high.reshape(b * c, 1, height_half, width_half), self.ad, stride=2
)
high_low = F.conv_transpose2d(
high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2
)
high_high = F.conv_transpose2d(
high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2
)
return (low_low + low_high + high_low + high_high).reshape(b, c, height, width)
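The 2D transform pair above reconstructs its input exactly because the four filters are orthonormal. A minimal sketch on a single 2x2 block, in plain Python, makes this checkable by hand. `haar_forward` and `haar_inverse` are hypothetical names, and the filter values are copied from the `aa`, `ad`, `da`, `dd` buffers.

```python
FILTERS = {
    "aa": [[0.5, 0.5], [0.5, 0.5]],      # low_low
    "ad": [[0.5, 0.5], [-0.5, -0.5]],    # low_high
    "da": [[0.5, -0.5], [0.5, -0.5]],    # high_low
    "dd": [[0.5, -0.5], [-0.5, 0.5]],    # high_high
}


def haar_forward(block):
    """Project one 2x2 block onto each filter (a stride-2 conv over a single window)."""
    return {
        name: sum(f[i][j] * block[i][j] for i in range(2) for j in range(2))
        for name, f in FILTERS.items()
    }


def haar_inverse(coeffs):
    """Sum the coefficient-weighted filters (a stride-2 transposed conv over a single window)."""
    return [
        [sum(coeffs[name] * f[i][j] for name, f in FILTERS.items()) for j in range(2)]
        for i in range(2)
    ]


block = [[1.0, 2.0], [3.0, 4.0]]
coeffs = haar_forward(block)
restored = haar_inverse(coeffs)
```

The 3D variants follow the same pattern with eight 2x2x2 filters. Because their scale is the rounded constant 0.3536, reconstruction there is only approximately exact.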
================================================
FILE: opensora/models/causalvideovae/model/registry.py
================================================
class ModelRegistry:
_models = {}
@classmethod
def register(cls, model_name):
def decorator(model_class):
cls._models[model_name] = model_class
return model_class
return decorator
@classmethod
def get_model(cls, model_name):
return cls._models.get(model_name)
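A short usage sketch for the registry follows. The class body is repeated so the example runs on its own; `ToyVAE` and the name `"NotRegistered"` are made up for illustration and are not models from this repository.

```python
class ModelRegistry:
    _models = {}

    @classmethod
    def register(cls, model_name):
        def decorator(model_class):
            cls._models[model_name] = model_class
            return model_class
        return decorator

    @classmethod
    def get_model(cls, model_name):
        return cls._models.get(model_name)


@ModelRegistry.register("ToyVAE")   # hypothetical model name for this sketch
class ToyVAE:
    def __init__(self, latent_dim=4):
        self.latent_dim = latent_dim


model_cls = ModelRegistry.get_model("ToyVAE")
model = model_cls(latent_dim=8)
missing = ModelRegistry.get_model("NotRegistered")
```

Because `get_model` uses `dict.get`, an unknown name returns `None` instead of raising, so callers may want to check the result before instantiating it.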
================================================
FILE: opensora/models/causalvideovae/model/trainer_videobase.py
================================================
from transformers import Trainer
import torch.nn.functional as F
from typing import Optional
import os
import torch
from transformers.utils import WEIGHTS_NAME
import json
class VideoBaseTrainer(Trainer):
def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
if state_dict is None:
state_dict = self.model.state_dict()
# get model config
model_config = self.model.config.to_dict()
# add more information
model_config['model'] = self.model.__class__.__name__
with open(os.path.join(output_dir, "config.json"), "w") as file:
json.dump(model_config, file)
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
================================================
FILE: opensora/models/causalvideovae/model/utils/__init__.py
================================================
================================================
FILE: opensora/models/causalvideovae/model/utils/distrib_utils.py
================================================
import torch
import numpy as np
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype)
        return x
    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])
    def nll(self, sample, dims=(1, 2, 3)):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)
    def mode(self):
        return self.mean
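# Worked sketch (not part of the repository): the `other is None` branch of `kl` above
# is the closed-form KL divergence between a diagonal Gaussian and a standard normal,
# 0.5 * sum(mean^2 + var - 1 - logvar). The pure-Python helper `kl_to_standard_normal`
# below is a hypothetical name used only to check that formula on small inputs.

```python
import math


def kl_to_standard_normal(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, 1)), summed over all elements."""
    return 0.5 * sum(
        m * m + math.exp(lv) - 1.0 - lv
        for m, lv in zip(mean, logvar)
    )


# A standard normal has zero divergence from itself.
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# Shifting one mean by 1 with unit variance contributes 0.5 * 1^2.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
```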
================================================
FILE: opensora/models/causalvideovae/model/utils/module_utils.py
================================================
import importlib
Module = str
MODULES_BASE = "opensora.models.causalvideovae.model.modules."
def resolve_str_to_obj(str_val, append=True):
    if append:
        str_val = MODULES_BASE + str_val
    module_name, class_name = str_val.rsplit('.', 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def create_instance(module_class_str: str, **kwargs):
    module_name, class_name = module_class_str.rsplit('.', 1)
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
    return class_(**kwargs)
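# Usage sketch (not part of the repository): the string-to-class resolution above boils
# down to `rsplit`, `importlib.import_module`, and `getattr`. The helper
# `resolve_dotted_path` is a hypothetical name, and a standard-library class is used as
# the target so the snippet runs without the `opensora` package.

```python
import importlib


def resolve_dotted_path(dotted_path):
    # Split "package.module.ClassName" into the module path and the attribute name.
    module_name, attr_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


ordered_dict_cls = resolve_dotted_path("collections.OrderedDict")
instance = ordered_dict_cls(a=1)
```

# A misspelled class name surfaces as an AttributeError from getattr, which is worth
# keeping in mind when module names are passed around as configuration strings.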
================================================
FILE: opensora/models/causalvideovae/model/utils/scheduler_utils.py
================================================
import torch
def cosine_scheduler(step, max_steps, value_base=1, value_end=0):
    step = torch.tensor(step)
    cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps))
    value = value_end + (value_base - value_end) * cosine_value
    return value
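# Worked sketch (not part of the repository): the cosine schedule above decays from
# `value_base` at step 0 to `value_end` at `max_steps`. The helper `cosine_value` is a
# hypothetical pure-Python rewrite (using `math` instead of `torch`) that evaluates the
# same formula at the start, midpoint, and end.

```python
import math


def cosine_value(step, max_steps, value_base=1.0, value_end=0.0):
    # Half-cosine ramp: 1 at step 0, 0.5 at the midpoint, 0 at max_steps.
    cosine = 0.5 * (1 + math.cos(math.pi * step / max_steps))
    return value_end + (value_base - value_end) * cosine


start = cosine_value(0, 100)
middle = cosine_value(50, 100)
end = cosine_value(100, 100)
```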
================================================
FILE: opensora/models/causalvideovae/model/utils/video_utils.py
================================================
import torch
import numpy as np
def tensor_to_video(x):
    # Map to [-1, 1], clamp, and map back: the net effect is clamping x to [0, 1].
    x = (x * 2 - 1).detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(1, 0, 2, 3).float().numpy()  # c t h w -> t c h w
    x = (255 * x).astype(np.uint8)
    return x
================================================
FILE: opensora/models/causalvideovae/model/utils/wavelet_utils.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from ..modules import CausalConv3d
from einops import rearrange
class HaarWaveletTransform3D(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536
g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536
hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536
gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536
h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536
g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536
hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536
gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536
h = h.view(1, 1, 2, 2, 2)
g = g.view(1, 1, 2, 2, 2)
hh = hh.view(1, 1, 2, 2, 2)
gh = gh.view(1, 1, 2, 2, 2)
h_v = h_v.view(1, 1, 2, 2, 2)
g_v = g_v.view(1, 1, 2, 2, 2)
hh_v = hh_v.view(1, 1, 2, 2, 2)
gh_v = gh_v.view(1, 1, 2, 2, 2)
self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False)
self.h_conv.conv.weight.data = h
self.g_conv.conv.weight.data = g
self.hh_conv.conv.weight.data = hh
self.gh_conv.conv.weight.data = gh
self.h_v_conv.conv.weight.data = h_v
self.g_v_conv.conv.weight.data = g_v
self.hh_v_conv.conv.weight.data = hh_v
self.gh_v_conv.conv.weight.data = gh_v
self.h_conv.requires_grad_(False)
self.g_conv.requires_grad_(False)
self.hh_conv.requires_grad_(False)
self.gh_conv.requires_grad_(False)
self.h_v_conv.requires_grad_(False)
self.g_v_conv.requires_grad_(False)
self.hh_v_conv.requires_grad_(False)
self.gh_v_conv.requires_grad_(False)
def forward(self, x):
assert x.dim() == 5
b = x.shape[0]
x = rearrange(x, "b c t h w -> (b c) 1 t h w")
low_low_low = self.h_conv(x)
low_low_low = rearrange(low_low_low, "(b c) 1 t h w -> b c t h w", b=b)
low_low_high = self.g_conv(x)
low_low_high = rearrange(low_low_high, "(b c) 1 t h w -> b c t h w", b=b)
low_high_low = self.hh_conv(x)
low_high_low = rearrange(low_high_low, "(b c) 1 t h w -> b c t h w", b=b)
low_high_high = self.gh_conv(x)
low_high_high = rearrange(low_high_high, "(b c) 1 t h w -> b c t h w", b=b)
high_low_low = self.h_v_conv(x)
high_low_low = rearrange(high_low_low, "(b c) 1 t h w -> b c t h w", b=b)
high_low_high = self.g_v_conv(x)
high_low_high = rearrange(high_low_high, "(b c) 1 t h w -> b c t h w", b=b)
high_high_low = self.hh_v_conv(x)
high_high_low = rearrange(high_high_low, "(b c) 1 t h w -> b c t h w", b=b)
high_high_high = self.gh_v_conv(x)
high_high_high = rearrange(high_high_high, "(b c) 1 t h w -> b c t h w", b=b)
output = torch.cat(
[
low_low_low,
low_low_high,
low_high_low,
low_high_high,
high_low_low,
high_low_high,
high_high_low,
high_high_high,
],
dim=1,
)
return output
class InverseHaarWaveletTransform3D(nn.Module):
def __init__(self, enable_cached=False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.register_buffer('h',
torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('g',
torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('hh',
torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('gh',
torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('h_v',
torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('g_v',
torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('hh_v',
torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.register_buffer('gh_v',
torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
self.enable_cached = enable_cached
self.causal_cached = None
def forward(self, coeffs):
assert coeffs.dim() == 5
b = coeffs.shape[0]
(
low_low_low,
low_low_high,
low_high_low,
low_high_high,
high_low_low,
high_low_high,
high_high_low,
high_high_high,
) = coeffs.chunk(8, dim=1)
low_low_low = rearrange(low_low_low, "b c t h w -> (b c) 1 t h w")
low_low_high = rearrange(low_low_high, "b c t h w -> (b c) 1 t h w")
low_high_low = rearrange(low_high_low, "b c t h w -> (b c) 1 t h w")
low_high_high = rearrange(low_high_high, "b c t h w -> (b c) 1 t h w")
high_low_low = rearrange(high_low_low, "b c t h w -> (b c) 1 t h w")
high_low_high = rearrange(high_low_high, "b c t h w -> (b c) 1 t h w")
high_high_low = rearrange(high_high_low, "b c t h w -> (b c) 1 t h w")
high_high_high = rearrange(high_high_high, "b c t h w -> (b c) 1 t h w")
low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2)
low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2)
low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2)
low_high_high = F.conv_transpose3d(low_high_high, self.gh, stride=2)
high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2)
high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2)
high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2)
high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2)
if self.enable_cached and self.causal_cached:
reconstructed = (
low_low_low
+ low_low_high
+ low_high_low
+ low_high_high
+ high_low_low
+ high_low_high
+ high_high_low
+ high_high_high
)
else:
reconstructed = (
low_low_low[:, :, 1:]
+ low_low_high[:, :, 1:]
+ low_high_low[:, :, 1:]
+ low_high_high[:, :, 1:]
+ high_low_low[:, :, 1:]
+ high_low_high[:, :, 1:]
+ high_high_low[:, :, 1:]
+ high_high_high[:, :, 1:]
)
self.causal_cached = True
reconstructed = rearrange(reconstructed, "(b c) 1 t h w -> b c t h w", b=b)
return reconstructed
class HaarWaveletTransform2D(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)
def forward(self, x):
b, c, h, w = x.shape
x = x.reshape(b * c, 1, h, w)
low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2)
low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2)
high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2)
high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2)
coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1)
return coeffs
class InverseHaarWaveletTransform2D(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2)
self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2)
def forward(self, coeffs):
low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1)
b, c, height_half, width_half = low_low.shape
height = height_half * 2
width = width_half * 2
low_low = F.conv_transpose2d(
low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2
)
low_high = F.conv_transpose2d(
low_high.reshape(b * c, 1, height_half, width_half), self.ad, stride=2
)
high_low = F.conv_transpose2d(
high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2
)
high_high = F.conv_transpose2d(
high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2
)
return (low_low + low_high + high_low + high_high).reshape(b, c, height, width)
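# Worked sketch (not part of the repository): the 2D Haar filters above (`aa`, `ad`,
# `da`, `dd`, each divided by 2) form an orthonormal basis on a 2x2 block, which is why
# the transposed convolution with the same filters inverts the forward transform. The
# helpers `haar2d_forward` and `haar2d_inverse` are hypothetical pure-Python versions
# that act on a single 2x2 block.

```python
def haar2d_forward(block):
    (a, b), (c, d) = block
    low_low = (a + b + c + d) / 2    # aa filter: block average (scaled)
    low_high = (a + b - c - d) / 2   # ad filter: top row minus bottom row
    high_low = (a - b + c - d) / 2   # da filter: left column minus right column
    high_high = (a - b - c + d) / 2  # dd filter: diagonal difference
    return low_low, low_high, high_low, high_high


def haar2d_inverse(coeffs):
    ll, lh, hl, hh = coeffs
    # Each output pixel is the sum of coefficient * filter tap at that position.
    return [
        [(ll + lh + hl + hh) / 2, (ll + lh - hl - hh) / 2],
        [(ll - lh + hl - hh) / 2, (ll - lh - hl + hh) / 2],
    ]


block = [[1.0, 2.0], [3.0, 4.0]]
coeffs = haar2d_forward(block)
restored = haar2d_inverse(coeffs)
```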
================================================
FILE: opensora/models/causalvideovae/model/vae/__init__.py
================================================
from .modeling_causalvae import CausalVAEModel
from .modeling_wfvae import WFVAEModel
from einops import rearrange
from torch import nn
================================================
FILE: opensora/models/causalvideovae/model/vae/modeling_causalvae.py
================================================
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:  # NPU support is optional; avoid a bare except that also swallows KeyboardInterrupt
torch_npu = None
npu_config = None
from ..modeling_videobase import VideoBaseAE
from ..modules import Normalize
from ..modules.ops import nonlinearity
from typing import Tuple
import torch.nn as nn
from ..utils.module_utils import resolve_str_to_obj, Module
from ..utils.distrib_utils import DiagonalGaussianDistribution
from ..registry import ModelRegistry
import torch
from diffusers.configuration_utils import register_to_config
from copy import deepcopy
import os
class Encoder(nn.Module):
def __init__(
self,
z_channels: int,
hidden_size: int,
hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
attn_resolutions: Tuple[int] = (16,),
conv_in: Module = "Conv2d",
        conv_out: Module = "CausalConv3d",
attention: Module = "AttnBlock",
resnet_blocks: Tuple[Module] = (
"ResnetBlock2D",
"ResnetBlock2D",
"ResnetBlock2D",
"ResnetBlock3D",
),
spatial_downsample: Tuple[Module] = (
"Downsample",
"Downsample",
"Downsample",
"",
),
temporal_downsample: Tuple[Module] = ("", "", "TimeDownsampleRes2x", ""),
mid_resnet: Module = "ResnetBlock3D",
dropout: float = 0.0,
resolution: int = 256,
num_res_blocks: int = 2,
double_z: bool = True,
norm_type: str = "groupnorm",
) -> None:
super().__init__()
        assert len(resnet_blocks) == len(hidden_size_mult), (
            f"resnet_blocks {resnet_blocks} and hidden_size_mult {hidden_size_mult} "
            "must have the same length"
        )
# ---- Config ----
self.num_resolutions = len(hidden_size_mult)
self.resolution = resolution
self.num_res_blocks = num_res_blocks
# ---- In ----
self.conv_in = resolve_str_to_obj(conv_in)(
3, hidden_size, kernel_size=3, stride=1, padding=1
)
# ---- Downsample ----
curr_res = resolution
in_ch_mult = (1,) + tuple(hidden_size_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = hidden_size * in_ch_mult[i_level]
block_out = hidden_size * hidden_size_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
resolve_str_to_obj(resnet_blocks[i_level])(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
norm_type=norm_type
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(resolve_str_to_obj(attention)(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if spatial_downsample[i_level]:
down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(
block_in, block_in
)
curr_res = curr_res // 2
if temporal_downsample[i_level]:
down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(
block_in, block_in
)
self.down.append(down)
# ---- Mid ----
self.mid = nn.Module()
self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
norm_type=norm_type
)
self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
norm_type=norm_type
)
# ---- Out ----
self.norm_out = Normalize(block_in, norm_type=norm_type)
self.conv_out = resolve_str_to_obj(conv_out)(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if hasattr(self.down[i_level], "downsample"):
h = self.down[i_level].downsample(h)
if hasattr(self.down[i_level], "time_downsample"):
h = self.down[i_level].time_downsample(h)
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
if npu_config is None:
h = self.norm_out(h)
else:
h = npu_config.run_group_norm(self.norm_out, h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
z_channels: int,
hidden_size: int,
hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
attn_resolutions: Tuple[int] = (16,),
conv_in: Module = "Conv2d",
        conv_out: Module = "CausalConv3d",
attention: Module = "AttnBlock",
resnet_blocks: Tuple[Module] = (
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
),
spatial_upsample: Tuple[Module] = (
"",
"SpatialUpsample2x",
"SpatialUpsample2x",
"SpatialUpsample2x",
),
temporal_upsample: Tuple[Module] = ("", "", "", "TimeUpsampleRes2x"),
mid_resnet: Module = "ResnetBlock3D",
dropout: float = 0.0,
resolution: int = 256,
num_res_blocks: int = 2,
norm_type: str = "groupnorm",
):
super().__init__()
# ---- Config ----
self.num_resolutions = len(hidden_size_mult)
self.resolution = resolution
self.num_res_blocks = num_res_blocks
# ---- In ----
block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.conv_in = resolve_str_to_obj(conv_in)(
z_channels, block_in, kernel_size=3, padding=1
)
# ---- Mid ----
self.mid = nn.Module()
self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
norm_type=norm_type
)
self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, norm_type=norm_type)
self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
norm_type=norm_type
)
# ---- Upsample ----
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = hidden_size * hidden_size_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
resolve_str_to_obj(resnet_blocks[i_level])(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
norm_type=norm_type
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(resolve_str_to_obj(attention)(block_in, norm_type=norm_type))
up = nn.Module()
up.block = block
up.attn = attn
if spatial_upsample[i_level]:
up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(
block_in, block_in
)
curr_res = curr_res * 2
if temporal_upsample[i_level]:
up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(
block_in, block_in
)
self.up.insert(0, up)
# ---- Out ----
self.norm_out = Normalize(block_in, norm_type=norm_type)
self.conv_out = resolve_str_to_obj(conv_out)(
block_in, 3, kernel_size=3, padding=1
)
def forward(self, z):
h = self.conv_in(z)
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if hasattr(self.up[i_level], "upsample"):
h = self.up[i_level].upsample(h)
if hasattr(self.up[i_level], "time_upsample"):
h = self.up[i_level].time_upsample(h)
if npu_config is None:
h = self.norm_out(h)
else:
h = npu_config.run_group_norm(self.norm_out, h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
@ModelRegistry.register("CausalVAE")
class CausalVAEModel(VideoBaseAE):
@register_to_config
def __init__(
self,
hidden_size: int = 128,
z_channels: int = 4,
hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (),
dropout: float = 0.0,
resolution: int = 256,
double_z: bool = True,
embed_dim: int = 4,
num_res_blocks: int = 2,
q_conv: str = "CausalConv3d",
encoder_conv_in: Module = "CausalConv3d",
encoder_conv_out: Module = "CausalConv3d",
encoder_attention: Module = "AttnBlock3D",
encoder_resnet_blocks: Tuple[Module] = (
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
),
encoder_spatial_downsample: Tuple[Module] = (
"SpatialDownsample2x",
"SpatialDownsample2x",
"SpatialDownsample2x",
"",
),
encoder_temporal_downsample: Tuple[Module] = (
"",
"TimeDownsample2x",
"TimeDownsample2x",
"",
),
encoder_mid_resnet: Module = "ResnetBlock3D",
decoder_conv_in: Module = "CausalConv3d",
decoder_conv_out: Module = "CausalConv3d",
decoder_attention: Module = "AttnBlock3D",
decoder_resnet_blocks: Tuple[Module] = (
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
),
decoder_spatial_upsample: Tuple[Module] = (
"",
"SpatialUpsample2x",
"SpatialUpsample2x",
"SpatialUpsample2x",
),
decoder_temporal_upsample: Tuple[Module] = (
"",
"",
"TimeUpsample2x",
"TimeUpsample2x",
),
decoder_mid_resnet: Module = "ResnetBlock3D",
use_quant_layer: bool = True,
norm_type: str = "groupnorm",
) -> None:
super().__init__()
self.tile_sample_min_size = 512000
self.tile_sample_min_size_t = 33
self.tile_sample_min_size_dec = 512
self.tile_sample_min_size_t_dec = 17
self.tile_latent_min_size = int(self.tile_sample_min_size_dec / (2 ** (len(hidden_size_mult) - 1)))
self.tile_latent_min_size_t = int((self.tile_sample_min_size_t_dec-1) / 4) + 1
self.tile_overlap_t = 2
self.tile_overlap_factor = 0.125
self.use_tiling = False
self.use_quant_layer = use_quant_layer
self.encoder = Encoder(
z_channels=z_channels,
hidden_size=hidden_size,
hidden_size_mult=hidden_size_mult,
attn_resolutions=attn_resolutions,
conv_in=encoder_conv_in,
conv_out=encoder_conv_out,
attention=encoder_attention,
resnet_blocks=encoder_resnet_blocks,
spatial_downsample=encoder_spatial_downsample,
temporal_downsample=encoder_temporal_downsample,
mid_resnet=encoder_mid_resnet,
dropout=dropout,
resolution=resolution,
num_res_blocks=num_res_blocks,
double_z=double_z,
norm_type=norm_type
)
self.decoder = Decoder(
z_channels=z_channels,
hidden_size=hidden_size,
hidden_size_mult=hidden_size_mult,
attn_resolutions=attn_resolutions,
conv_in=decoder_conv_in,
conv_out=decoder_conv_out,
attention=decoder_attention,
resnet_blocks=decoder_resnet_blocks,
spatial_upsample=decoder_spatial_upsample,
temporal_upsample=decoder_temporal_upsample,
mid_resnet=decoder_mid_resnet,
dropout=dropout,
resolution=resolution,
num_res_blocks=num_res_blocks,
norm_type=norm_type
)
if self.use_quant_layer:
quant_conv_cls = resolve_str_to_obj(q_conv)
self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
def get_encoder(self):
if self.use_quant_layer:
return [self.quant_conv, self.encoder]
return [self.encoder]
def get_decoder(self):
if self.use_quant_layer:
return [self.post_quant_conv, self.decoder]
return [self.decoder]
def encode(self, x):
if self.use_tiling and (
x.shape[-1] > self.tile_sample_min_size
or x.shape[-2] > self.tile_sample_min_size
or x.shape[-3] > self.tile_sample_min_size_t
):
return self.tiled_encode(x)
h = self.encoder(x)
if self.use_quant_layer:
h = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(h)
return posterior
def decode(self, z):
if self.use_tiling and (
z.shape[-1] > self.tile_latent_min_size
or z.shape[-2] > self.tile_latent_min_size
or z.shape[-3] > self.tile_latent_min_size_t
):
return self.tiled_decode(z)
if self.use_quant_layer:
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def on_train_start(self):
        self.ema = deepcopy(self) if self.save_ema else None
def get_last_layer(self):
if hasattr(self.decoder.conv_out, "conv"):
return self.decoder.conv_out.conv.weight
else:
return self.decoder.conv_out.weight
def blend_v(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
1 - y / blend_extent
) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
1 - x / blend_extent
) + b[:, :, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x):
t = x.shape[2]
t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)]
# print('tiled_encode', t_chunk_idx)
if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
t_chunk_start_end = [[0, t]]
else:
t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1+(self.tile_overlap_t-1)*4] for i in range(len(t_chunk_idx)-1)]
if t_chunk_start_end[-1][-1] > t:
t_chunk_start_end[-1][-1] = t
elif t_chunk_start_end[-1][-1] < t:
last_start_end = [t_chunk_idx[-1], t]
t_chunk_start_end.append(last_start_end)
moments = []
# print('tiled_encode t_chunk_start_end', t_chunk_start_end)
for idx, (start, end) in enumerate(t_chunk_start_end):
chunk_x = x[:, :, start: end]
if idx != 0:
moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1+(self.tile_overlap_t-1):]
else:
moment = self.tiled_encode2d(chunk_x, return_moments=True)
moments.append(moment)
moments = torch.cat(moments, dim=2)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def tiled_decode(self, x):
t = x.shape[2]
t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)]
# print('tiled_decode', t_chunk_idx)
if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
t_chunk_start_end = [[0, t]]
else:
t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1+(self.tile_overlap_t-1)] for i in range(len(t_chunk_idx)-1)]
if t_chunk_start_end[-1][-1] > t:
t_chunk_start_end[-1][-1] = t
elif t_chunk_start_end[-1][-1] < t:
last_start_end = [t_chunk_idx[-1], t]
t_chunk_start_end.append(last_start_end)
dec_ = []
# print('tiled_decode t_chunk_start_end', t_chunk_start_end)
for idx, (start, end) in enumerate(t_chunk_start_end):
chunk_x = x[:, :, start: end]
if idx != 0:
dec = self.tiled_decode2d(chunk_x)[:, :, 1+(self.tile_overlap_t-1)*4:]
else:
dec = self.tiled_decode2d(chunk_x)
# print(chunk_x.shape, dec.shape)
dec_.append(dec)
dec_ = torch.cat(dec_, dim=2)
return dec_
def tiled_encode2d(self, x, return_moments=False):
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# print('overlap_size, blend_extent, row_limit', overlap_size, blend_extent, row_limit)
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[3], overlap_size):
row = []
for j in range(0, x.shape[4], overlap_size):
# print(i, j)
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
if self.use_quant_layer:
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
moments = torch.cat(result_rows, dim=3)
posterior = DiagonalGaussianDistribution(moments)
if return_moments:
return moments
return posterior
def tiled_decode2d(self, z):
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
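        # Decoded tiles are tile_sample_min_size_dec pixels wide (tile_latent_min_size is derived from it).
        blend_extent = int(self.tile_sample_min_size_dec * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size_dec - blend_extent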
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
# print('tiled_decode2d', list(range(0, z.shape[3], overlap_size)), list(range(0, z.shape[4], overlap_size)))
rows = []
for i in range(0, z.shape[3], overlap_size):
row = []
for j in range(0, z.shape[4], overlap_size):
tile = z[
:,
:,
:,
i : i + self.tile_latent_min_size,
j : j + self.tile_latent_min_size,
]
if self.use_quant_layer:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
return dec
def enable_tiling(self, use_tiling: bool = True):
self.use_tiling = use_tiling
def disable_tiling(self):
self.enable_tiling(False)
    def init_from_ckpt(self, path, ignore_keys=()):
sd = torch.load(path, map_location="cpu")
print("init from " + path)
if (
"ema_state_dict" in sd
and len(sd["ema_state_dict"]) > 0
            and os.environ.get("NOT_USE_EMA_MODEL", "0") == "0"  # environment values are strings
):
print("Load from ema model!")
sd = sd["ema_state_dict"]
sd = {key.replace("module.", ""): value for key, value in sd.items()}
elif "state_dict" in sd:
print("Load from normal model!")
if "gen_model" in sd["state_dict"]:
sd = sd["state_dict"]["gen_model"]
else:
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
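# Worked sketch (not part of the repository): the temporal chunking used by
# `tiled_encode` above, reduced to index arithmetic. The helper `temporal_chunks` and
# its parameter names are hypothetical; the defaults mirror `tile_sample_min_size_t = 33`,
# `tile_overlap_t = 2`, and the 4x temporal compression assumed by the original code.

```python
def temporal_chunks(t, tile_t=33, overlap_t=2, stride=4):
    # Chunk starts advance by tile_t - 1 frames.
    starts = list(range(0, t, tile_t - 1))
    if len(starts) == 1:
        return [[0, t]]
    # Each chunk runs one frame past the next start, plus the overlap frames.
    chunks = [
        [starts[i], starts[i + 1] + 1 + (overlap_t - 1) * stride]
        for i in range(len(starts) - 1)
    ]
    if chunks[-1][1] > t:
        # Clip the final chunk to the clip length.
        chunks[-1][1] = t
    elif chunks[-1][1] < t:
        # Otherwise add a trailing chunk that covers the remaining frames.
        chunks.append([starts[-1], t])
    return chunks


chunks_65 = temporal_chunks(65)
chunks_17 = temporal_chunks(17)
```

# For a 65-frame clip this gives two chunks, [0, 37) and [32, 65). With 4x causal
# compression they encode to 10 and 9 latent frames; dropping the first
# 1 + (overlap_t - 1) = 2 latent frames of the second chunk leaves 10 + 7 = 17,
# which matches (65 - 1) / 4 + 1.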
================================================
FILE: opensora/models/causalvideovae/model/vae/modeling_wfvae.py
================================================
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:  # NPU support is optional; avoid a bare except that also swallows KeyboardInterrupt
torch_npu = None
npu_config = None
from ..modeling_videobase import VideoBaseAE
from diffusers.configuration_utils import register_to_config
import torch
import torch.nn as nn
from ..modules import (
ResnetBlock2D,
ResnetBlock3D,
Conv2d,
HaarWaveletTransform3D,
InverseHaarWaveletTransform3D,
CausalConv3d,
Normalize,
AttnBlock3DFix,
nonlinearity,
)
from ..utils.distrib_utils import DiagonalGaussianDistribution
from copy import deepcopy
import os
from ..registry import ModelRegistry
from einops import rearrange
from collections import deque
from ..utils.module_utils import resolve_str_to_obj, Module
from typing import List
class Encoder(VideoBaseAE):
@register_to_config
def __init__(
self,
latent_dim: int = 8,
base_channels: int = 128,
num_resblocks: int = 2,
energy_flow_hidden_size: int = 64,
dropout: float = 0.0,
attention_type: str = "AttnBlock3DFix",
use_attention: bool = True,
norm_type: str = "groupnorm",
l1_dowmsample_block: str = "Downsample",
l1_downsample_wavelet: str = "HaarWaveletTransform2D",
l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample",
l2_downsample_wavelet: str = "HaarWaveletTransform3D",
) -> None:
super().__init__()
self.down1 = nn.Sequential(
Conv2d(24, base_channels, kernel_size=3, stride=1, padding=1),
*[
ResnetBlock2D(
in_channels=base_channels,
out_channels=base_channels,
dropout=dropout,
norm_type=norm_type,
)
for _ in range(num_resblocks)
],
resolve_str_to_obj(l1_dowmsample_block)(in_channels=base_channels, out_channels=base_channels),
)
self.down2 = nn.Sequential(
Conv2d(
base_channels + energy_flow_hidden_size,
base_channels * 2,
kernel_size=3,
stride=1,
padding=1,
),
*[
ResnetBlock3D(
in_channels=base_channels * 2,
out_channels=base_channels * 2,
dropout=dropout,
norm_type=norm_type,
)
for _ in range(num_resblocks)
],
resolve_str_to_obj(l2_dowmsample_block)(base_channels * 2, base_channels * 2),
)
# Connection
        if l1_dowmsample_block == "Downsample":  # Temporary workaround: the channel count should follow l1_downsample_wavelet (2D Haar: 3 * 4 = 12, 3D Haar: 3 * 8 = 24).
l1_channels = 12
else:
l1_channels = 24
self.connect_l1 = Conv2d(
l1_channels, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1
)
self.connect_l2 = Conv2d(
24, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1
)
# Mid
mid_layers = [
ResnetBlock3D(
in_channels=base_channels * 2 + energy_flow_hidden_size,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
),
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
),
]
if use_attention:
mid_layers.insert(
1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type)
)
self.mid = nn.Sequential(*mid_layers)
self.norm_out = Normalize(base_channels * 4, norm_type=norm_type)
self.conv_out = CausalConv3d(
base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1
)
self.wavelet_transform_in = HaarWaveletTransform3D()
self.wavelet_transform_l1 = resolve_str_to_obj(l1_downsample_wavelet)()
self.wavelet_transform_l2 = resolve_str_to_obj(l2_downsample_wavelet)()
def forward(self, x):
coeffs = self.wavelet_transform_in(x)
l1_coeffs = coeffs[:, :3]
l1_coeffs = self.wavelet_transform_l1(l1_coeffs)
l1 = self.connect_l1(l1_coeffs)
l2_coeffs = self.wavelet_transform_l2(l1_coeffs[:, :3])
l2 = self.connect_l2(l2_coeffs)
h = self.down1(coeffs)
h = torch.concat([h, l1], dim=1)
h = self.down2(h)
h = torch.concat([h, l2], dim=1)
h = self.mid(h)
if npu_config is None:
h = self.norm_out(h)
else:
h = npu_config.run_group_norm(self.norm_out, h)
h = nonlinearity(h)
h = self.conv_out(h)
return h, (l1_coeffs, l2_coeffs)
class Decoder(VideoBaseAE):
@register_to_config
def __init__(
self,
latent_dim: int = 8,
base_channels: int = 128,
num_resblocks: int = 2,
dropout: float = 0.0,
energy_flow_hidden_size: int = 128,
attention_type: str = "AttnBlock3DFix",
use_attention: bool = True,
norm_type: str = "groupnorm",
t_interpolation: str = "nearest",
connect_res_layer_num: int = 1,
l1_upsample_block: str = "Upsample",
l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D",
l2_upsample_block: str = "Spatial2xTime2x3DUpsample",
l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D",
) -> None:
super().__init__()
self.energy_flow_hidden_size = energy_flow_hidden_size
self.conv_in = CausalConv3d(
latent_dim, base_channels * 4, kernel_size=3, stride=1, padding=1
)
mid_layers = [
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
),
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4 + energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
),
]
if use_attention:
mid_layers.insert(
1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type)
)
self.mid = nn.Sequential(*mid_layers)
self.up2 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4,
dropout=dropout,
norm_type=norm_type,
)
for _ in range(num_resblocks)
],
resolve_str_to_obj(l2_upsample_block)(
base_channels * 4, base_channels * 4, t_interpolation=t_interpolation
),
ResnetBlock3D(
in_channels=base_channels * 4,
out_channels=base_channels * 4 + energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
),
)
self.up1 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=base_channels * (4 if i == 0 else 2),
out_channels=base_channels * 2,
dropout=dropout,
norm_type=norm_type,
)
for i in range(num_resblocks)
],
resolve_str_to_obj(l1_upsample_block)(in_channels=base_channels * 2, out_channels=base_channels * 2),
ResnetBlock3D(
in_channels=base_channels * 2,
out_channels=base_channels * 2,
dropout=dropout,
norm_type=norm_type,
),
)
self.layer = nn.Sequential(
*[
ResnetBlock3D(
in_channels=base_channels * (2 if i == 0 else 1),
out_channels=base_channels,
dropout=dropout,
norm_type=norm_type,
)
for i in range(2)
],
)
# Connection
if l1_upsample_block == "Upsample":  # FIXME: temporary hack. The inverse 2D Haar transform expects 12 coefficient channels; the inverse 3D transform expects 24.
l1_channels = 12
else:
l1_channels = 24
self.connect_l1 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=energy_flow_hidden_size,
out_channels=energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
)
for _ in range(connect_res_layer_num)
],
Conv2d(energy_flow_hidden_size, l1_channels, kernel_size=3, stride=1, padding=1),
)
self.connect_l2 = nn.Sequential(
*[
ResnetBlock3D(
in_channels=energy_flow_hidden_size,
out_channels=energy_flow_hidden_size,
dropout=dropout,
norm_type=norm_type,
)
for _ in range(connect_res_layer_num)
],
Conv2d(energy_flow_hidden_size, 24, kernel_size=3, stride=1, padding=1),
)
# Out
self.norm_out = Normalize(base_channels, norm_type=norm_type)
self.conv_out = Conv2d(base_channels, 24, kernel_size=3, stride=1, padding=1)
self.inverse_wavelet_transform_out = InverseHaarWaveletTransform3D()
self.inverse_wavelet_transform_l1 = resolve_str_to_obj(l1_upsample_wavelet)()
self.inverse_wavelet_transform_l2 = resolve_str_to_obj(l2_upsample_wavelet)()
def forward(self, z):
h = self.conv_in(z)
h = self.mid(h)
l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size :])
l2 = self.inverse_wavelet_transform_l2(l2_coeffs)
h = self.up2(h[:, : -self.energy_flow_hidden_size])
l1_coeffs = h[:, -self.energy_flow_hidden_size :]
l1_coeffs = self.connect_l1(l1_coeffs)
l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2
l1 = self.inverse_wavelet_transform_l1(l1_coeffs)
h = self.up1(h[:, : -self.energy_flow_hidden_size])
h = self.layer(h)
if npu_config is None:
h = self.norm_out(h)
else:
h = npu_config.run_group_norm(self.norm_out, h)
h = nonlinearity(h)
h = self.conv_out(h)
h[:, :3] = h[:, :3] + l1
dec = self.inverse_wavelet_transform_out(h)
return dec, (l1_coeffs, l2_coeffs)
@ModelRegistry.register("WFVAE")
class WFVAEModel(VideoBaseAE):
@register_to_config
def __init__(
self,
latent_dim: int = 8,
base_channels: int = 128,
encoder_num_resblocks: int = 2,
encoder_energy_flow_hidden_size: int = 64,
decoder_num_resblocks: int = 2,
decoder_energy_flow_hidden_size: int = 128,
attention_type: str = "AttnBlock3DFix",
use_attention: bool = True,
dropout: float = 0.0,
norm_type: str = "groupnorm",
t_interpolation: str = "nearest",
connect_res_layer_num: int = 1,
scale: List[float] = [0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215],
shift: List[float] = [0, 0, 0, 0, 0, 0, 0, 0],
# Module config
l1_dowmsample_block: str = "Downsample",
l1_downsample_wavelet: str = "HaarWaveletTransform2D",
l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample",
l2_downsample_wavelet: str = "HaarWaveletTransform3D",
l1_upsample_block: str = "Upsample",
l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D",
l2_upsample_block: str = "Spatial2xTime2x3DUpsample",
l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D",
) -> None:
super().__init__()
self.use_tiling = False
# Hardcode for now
self.t_chunk_enc = 8
self.t_chunk_dec = 2
self.t_upsample_times = 4 // 2
self.use_quant_layer = False
self.encoder = Encoder(
latent_dim=latent_dim,
base_channels=base_channels,
num_resblocks=encoder_num_resblocks,
energy_flow_hidden_size=encoder_energy_flow_hidden_size,
dropout=dropout,
use_attention=use_attention,
norm_type=norm_type,
l1_dowmsample_block=l1_dowmsample_block,
l1_downsample_wavelet=l1_downsample_wavelet,
l2_dowmsample_block=l2_dowmsample_block,
l2_downsample_wavelet=l2_downsample_wavelet,
attention_type=attention_type
)
self.decoder = Decoder(
latent_dim=latent_dim,
base_channels=base_channels,
num_resblocks=decoder_num_resblocks,
energy_flow_hidden_size=decoder_energy_flow_hidden_size,
dropout=dropout,
use_attention=use_attention,
norm_type=norm_type,
t_interpolation=t_interpolation,
connect_res_layer_num=connect_res_layer_num,
l1_upsample_block=l1_upsample_block,
l1_upsample_wavelet=l1_upsample_wavelet,
l2_upsample_block=l2_upsample_block,
l2_upsample_wavelet=l2_upsample_wavelet,
attention_type=attention_type
)
# Set cache offset for trilinear lossless upsample.
self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid], 1)
self._set_cache_offset([self.decoder.up2[-2:], self.decoder.up1, self.decoder.connect_l1, self.decoder.layer], self.t_upsample_times)
def get_encoder(self):
if self.use_quant_layer:
return [self.quant_conv, self.encoder]
return [self.encoder]
def get_decoder(self):
if self.use_quant_layer:
return [self.post_quant_conv, self.decoder]
return [self.decoder]
def _empty_causal_cached(self, parent):
for name, module in parent.named_modules():
if hasattr(module, 'causal_cached'):
module.causal_cached = deque()
def _set_causal_cached(self, enable_cached=True):
for name, module in self.named_modules():
if hasattr(module, 'enable_cached'):
module.enable_cached = enable_cached
def _set_cache_offset(self, modules, cache_offset=0):
for module in modules:
for submodule in module.modules():
if hasattr(submodule, 'cache_offset'):
submodule.cache_offset = cache_offset
def _set_first_chunk(self, is_first_chunk=True):
for module in self.modules():
if hasattr(module, 'is_first_chunk'):
module.is_first_chunk = is_first_chunk
def build_chunk_start_end(self, t, decoder_mode=False):
start_end = [[0, 1]]
start = 1
end = start
while True:
if start >= t:
break
end = min(t, end + (self.t_chunk_dec if decoder_mode else self.t_chunk_enc))
start_end.append([start, end])
start = end
return start_end
def encode(self, x):
self._empty_causal_cached(self.encoder)
self._set_first_chunk(True)
if self.use_tiling:
h = self.tile_encode(x)
l1, l2 = None, None
else:
h, (l1, l2) = self.encoder(x)
if self.use_quant_layer:
h = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(h)
return posterior
def tile_encode(self, x):
b, c, t, h, w = x.shape
start_end = self.build_chunk_start_end(t)
result = []
for idx, (start, end) in enumerate(start_end):
self._set_first_chunk(idx == 0)
chunk = x[:, :, start:end, :, :]
chunk = self.encoder(chunk)[0]
if self.use_quant_layer:
chunk = self.quant_conv(chunk)
result.append(chunk)
return torch.cat(result, dim=2)
def decode(self, z):
self._empty_causal_cached(self.decoder)
self._set_first_chunk(True)
if self.use_tiling:
dec = self.tile_decode(z)
l1, l2 = None, None
else:
if self.use_quant_layer:
z = self.post_quant_conv(z)
dec, (l1, l2) = self.decoder(z)
return dec
def tile_decode(self, x):
b, c, t, h, w = x.shape
start_end = self.build_chunk_start_end(t, decoder_mode=True)
result = []
for idx, (start, end) in enumerate(start_end):
self._set_first_chunk(idx==0)
if end + 1 < t:
chunk = x[:, :, start:end+1, :, :]
else:
chunk = x[:, :, start:end, :, :]
if self.use_quant_layer:
chunk = self.post_quant_conv(chunk)
chunk = self.decoder(chunk)[0]
if end + 1 < t:
chunk = chunk[:, :, :-4]
result.append(chunk.clone())
else:
result.append(chunk.clone())
return torch.cat(result, dim=2)
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_last_layer(self):
if hasattr(self.decoder.conv_out, "conv"):
return self.decoder.conv_out.conv.weight
else:
return self.decoder.conv_out.weight
def enable_tiling(self, use_tiling: bool = True):
self.use_tiling = use_tiling
self._set_causal_cached(use_tiling)
def disable_tiling(self):
self.enable_tiling(False)
def init_from_ckpt(self, path, ignore_keys=()):  # immutable default instead of a shared list
sd = torch.load(path, map_location="cpu")
print("init from " + path)
if (
"ema_state_dict" in sd
and len(sd["ema_state_dict"]) > 0
and os.environ.get("NOT_USE_EMA_MODEL", "0") == "0"  # environment values are strings, so compare against "0"
):
print("Load from ema model!")
sd = sd["ema_state_dict"]
sd = {key.replace("module.", ""): value for key, value in sd.items()}
elif "state_dict" in sd:
print("Load from normal model!")
if "gen_model" in sd["state_dict"]:
sd = sd["state_dict"]["gen_model"]
else:
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
break  # the key is gone; a second matching prefix would raise KeyError
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print(missing_keys, unexpected_keys)
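The temporal chunking used by `tile_encode` and `tile_decode` can be illustrated in isolation. The sketch below is a standalone rewrite of the loop in `build_chunk_start_end`, not the class method itself; `chunk_start_end` is a hypothetical helper name, and the frame counts (17 and 5) are example values chosen only for illustration. The chunk sizes 8 and 2 mirror the hardcoded `t_chunk_enc` and `t_chunk_dec`.

```python
def chunk_start_end(t, chunk_size):
    """Return [start, end) frame ranges: the first frame alone, then chunks of chunk_size."""
    ranges = [[0, 1]]  # the first frame is always processed on its own
    start = 1
    while start < t:
        end = min(t, start + chunk_size)  # the last chunk may be shorter
        ranges.append([start, end])
        start = end
    return ranges


# Encoder-style chunking (chunk size 8) over 17 input frames.
encoder_ranges = chunk_start_end(17, 8)
# Decoder-style chunking (chunk size 2) over 5 latent frames.
decoder_ranges = chunk_start_end(5, 2)
print(encoder_ranges)
print(decoder_ranges)
```

The first frame gets its own chunk, which is consistent with the causal design where `is_first_chunk` is set only for index 0.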
================================================
FILE: opensora/models/causalvideovae/sample/rec_video_vae.py
================================================
import argparse
from tqdm import tqdm
import torch
import sys
from torch.utils.data import DataLoader, Subset
import os
from accelerate import Accelerator
sys.path.append(".")
from opensora.models.causalvideovae.model import *
from opensora.models.causalvideovae.dataset.video_dataset import ValidVideoDataset
from opensora.models.causalvideovae.utils.video_utils import custom_to_video
@torch.no_grad()
def main(args: argparse.Namespace):
accelerator = Accelerator()
device = accelerator.device
real_video_dir = args.real_video_dir
generated_video_dir = args.generated_video_dir
sample_rate = args.sample_rate
resolution = args.resolution
crop_size = args.crop_size
num_frames = args.num_frames
device = args.device  # the --device flag takes precedence over accelerator.device
sample_fps = args.sample_fps
batch_size = args.batch_size
num_workers = args.num_workers
subset_size = args.subset_size
os.makedirs(args.generated_video_dir, exist_ok=True)
data_type = torch.bfloat16
# ---- Load Model ----
model_cls = ModelRegistry.get_model(args.model_name)
vae = model_cls.from_pretrained(args.from_pretrained)
vae = vae.to(device).to(data_type)
if args.enable_tiling:
vae.enable_tiling()
vae.tile_overlap_factor = args.tile_overlap_factor
# ---- Prepare Dataset ----
dataset = ValidVideoDataset(
real_video_dir=real_video_dir,
num_frames=num_frames,
sample_rate=sample_rate,
crop_size=crop_size,
resolution=resolution,
)
if subset_size:
indices = range(subset_size)
dataset = Subset(dataset, indices=indices)
dataloader = DataLoader(
dataset, batch_size=batch_size, pin_memory=False, num_workers=num_workers
)
dataloader = accelerator.prepare(dataloader)
# ---- Inference ----
for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process):
x, file_names = batch['video'], batch['file_name']
x = x.to(device=device, dtype=data_type) # b c t h w
x = x * 2 - 1
encode_result = vae.encode(x)
if isinstance(encode_result, tuple):
encode_result = encode_result[0]
latents = encode_result.sample().to(data_type)
video_recon = vae.decode(latents)
if isinstance(video_recon, tuple):
video_recon = video_recon[0]
for idx, video in enumerate(video_recon):
output_path = os.path.join(generated_video_dir, file_names[idx])
if args.output_origin:
os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True)
origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx])
custom_to_video(
x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path
)
custom_to_video(
video, fps=sample_fps / sample_rate, output_file=output_path
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--real_video_dir", type=str, default="")
parser.add_argument("--generated_video_dir", type=str, default="")
parser.add_argument("--from_pretrained", type=str, default="")
parser.add_argument("--sample_fps", type=int, default=30)
parser.add_argument("--resolution", type=int, default=336)
parser.add_argument("--crop_size", type=int, default=None)
parser.add_argument("--num_frames", type=int, default=17)
parser.add_argument("--sample_rate", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--subset_size", type=int, default=None)
parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
parser.add_argument('--enable_tiling', action='store_true')
parser.add_argument('--output_origin', action='store_true')
parser.add_argument("--model_name", type=str, default=None, help="")
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()
main(args)
================================================
FILE: opensora/models/causalvideovae/utils/__init__.py
================================================
================================================
FILE: opensora/models/causalvideovae/utils/dataset_utils.py
================================================
import math
from einops import rearrange
import decord
from torch.nn import functional as F
import torch
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
class DecordInit(object):
"""Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
def __init__(self, num_threads=1):
self.num_threads = num_threads
self.ctx = decord.cpu(0)
def __call__(self, filename):
"""Open a video file with Decord.
Args:
filename (str): Path to the video file to read.
Returns:
decord.VideoReader: Reader for the given file.
"""
reader = decord.VideoReader(filename,
ctx=self.ctx,
num_threads=self.num_threads)
return reader
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'num_threads={self.num_threads})')
return repr_str
def pad_to_multiple(number, ds_stride):
remainder = number % ds_stride
if remainder == 0:
return number
else:
padding = ds_stride - remainder
return number + padding
================================================
FILE: opensora/models/causalvideovae/utils/downloader.py
================================================
import gdown
import os
opensora_cache_home = os.path.expanduser(
os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora"))
)
def gdown_download(id, fname, cache_dir=None):
cache_dir = opensora_cache_home if not cache_dir else cache_dir
os.makedirs(cache_dir, exist_ok=True)
destination = os.path.join(cache_dir, fname)
if os.path.exists(destination):
return destination
gdown.download(id=id, output=destination, quiet=False)
return destination
================================================
FILE: opensora/models/causalvideovae/utils/video_utils.py
================================================
import torch
import numpy as np
import numpy.typing as npt
import cv2
from decord import VideoReader, cpu
def array_to_video(
image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4"
) -> None:
"""Write a (t, h, w, c) uint8 RGB frame array to a video file."""
height, width, channels = image_array[0].shape
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
for image in image_array:
image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
video_writer.write(image_rgb)
video_writer.release()
def custom_to_video(
x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4"
) -> None:
x = x.detach().cpu()
x = torch.clamp(x, -1, 1)
x = (x + 1) / 2
x = x.permute(1, 2, 3, 0).float().numpy()
x = (255 * x).astype(np.uint8)
array_to_video(x, fps=fps, output_file=output_file)
return
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8)
total_frames = len(decord_vr)
sample_frames_len = sample_rate * num_frames
if total_frames > sample_frames_len:
s = 0
e = s + sample_frames_len
else:
s = 0
e = total_frames
num_frames = int(total_frames / sample_frames_len * num_frames)
print(
f"sample_frames_len {sample_frames_len}, can only sample {num_frames * sample_rate}",
video_path,
total_frames,
)
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
return video_data
def tensor_to_video(x):
"""[0-1] tensor to video"""
x = (x * 2 - 1).detach().cpu()
x = torch.clamp(x, -1, 1)
x = (x + 1) / 2
x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w
x = (255 * x).astype(np.uint8)
return x
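The value mapping shared by `custom_to_video` and `tensor_to_video` can be checked on single values. This is a minimal scalar sketch, assuming the same clamp, shift, and truncating cast as the tensor code; `to_uint8` is a hypothetical name that does not exist in the repository.

```python
def to_uint8(value):
    """Map a value in [-1, 1] to a pixel intensity in [0, 255]."""
    clamped = max(-1.0, min(1.0, value))  # mirrors torch.clamp(x, -1, 1)
    unit = (clamped + 1) / 2              # shift from [-1, 1] to [0, 1]
    return int(255 * unit)                # truncates like astype(np.uint8)


pixels = [to_uint8(v) for v in (-1.0, 0.0, 1.0, 3.0)]
print(pixels)
```

Out-of-range values are clamped rather than wrapped, so a slightly overshooting reconstruction saturates at 0 or 255.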
================================================
FILE: opensora/models/diffusion/__init__.py
================================================
from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models
from .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models
Diffusion_models = {}
Diffusion_models.update(OpenSora_v1_3_models)
Diffusion_models.update(OpenSoraInpaint_v1_3_models)
from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models_class
from .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models_class
Diffusion_models_class = {}
Diffusion_models_class.update(OpenSora_v1_3_models_class)
Diffusion_models_class.update(OpenSoraInpaint_v1_3_models_class)
================================================
FILE: opensora/models/diffusion/common.py
================================================
import torch
from einops import rearrange, repeat
from typing import Any, Dict, Optional, Tuple
import torch.nn.functional as F
from torch import nn
from diffusers.models.attention_processor import Attention as Attention_
try:
import torch_npu
from opensora.npu_config import npu_config, set_run_dtype
from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info as xccl_info
from opensora.acceleration.communications import all_to_all_SBH
except Exception:  # fall back to the GPU code path without swallowing KeyboardInterrupt/SystemExit
torch_npu = None
npu_config = None
set_run_dtype = None
from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info as xccl_info
from opensora.utils.communications import all_to_all_SBH
class PatchEmbed2D(nn.Module):
"""2D patch embedding applied frame by frame to video latents."""
def __init__(
self,
patch_size=16,
in_channels=3,
embed_dim=768,
bias=True,
):
super().__init__()
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
)
def forward(self, latent):
b, _, _, _, _ = latent.shape
latent = rearrange(latent, 'b c t h w -> (b t) c h w')
latent = self.proj(latent)
latent = rearrange(latent, '(b t) c h w -> b (t h w) c', b=b)
return latent
class PositionGetter3D(object):
""" return positions of patches """
def __init__(self, ):
self.cache_positions = {}
def __call__(self, b, t, h, w, device):
if (b, t, h, w) not in self.cache_positions:
x = torch.arange(w, device=device)
y = torch.arange(h, device=device)
z = torch.arange(t, device=device)
pos = torch.cartesian_prod(z, y, x)
pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone()
poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous())
max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max()))
self.cache_positions[b, t, h, w] = (poses, max_poses)
pos = self.cache_positions[b, t, h, w]
return pos
class RoPE3D(torch.nn.Module):
def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)):
super().__init__()
self.base = freq
self.F0 = F0
self.interpolation_scale_t = interpolation_scale_thw[0]
self.interpolation_scale_h = interpolation_scale_thw[1]
self.interpolation_scale_w = interpolation_scale_thw[2]
self.cache = {}
def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1):
# Key the cache on interpolation_scale as well, so axes with different scales do not share entries.
cache_key = (D, seq_len, device, dtype, interpolation_scale)
if cache_key not in self.cache:
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
freqs = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos() # (Seq, Dim)
sin = freqs.sin()
self.cache[cache_key] = (cos, sin)
return self.cache[cache_key]
@staticmethod
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rope1d(self, tokens, pos1d, cos, sin):
assert pos1d.ndim == 2
# for (ntokens x batch_size x nheads x dim)
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
return (tokens * cos) + (self.rotate_half(tokens) * sin)
def forward(self, tokens, positions):
"""
input:
* tokens: ntokens x batch_size x nheads x dim
* positions: tuple (poses, max_poses) as returned by PositionGetter3D, where poses holds
three ntokens x batch_size index tensors (t, y and x position of each token)
output:
* tokens after applying RoPE3D (ntokens x batch_size x nheads x dim)
"""
assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three"
D = tokens.size(3) // 3
poses, max_poses = positions
assert len(poses) == 3 and poses[0].ndim == 2  # three (ntokens, batch_size) index tensors
cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t)
cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h)
cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w)
# split features into three along the feature dimension, and apply rope1d on each third
t, y, x = tokens.chunk(3, dim=-1)
t = self.apply_rope1d(t, poses[0], cos_t, sin_t)
y = self.apply_rope1d(y, poses[1], cos_y, sin_y)
x = self.apply_rope1d(x, poses[2], cos_x, sin_x)
tokens = torch.cat((t, y, x), dim=-1)
return tokens
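The rotation that `apply_rope1d` performs on each third of the feature dimension can be reproduced on plain Python lists. The sketch below is an illustration only, assuming the same rotate-half convention and frequency schedule as `RoPE3D`; `rope_1d` and `dot` are hypothetical helper names, and the vectors `q` and `k` are made-up values. It demonstrates the property that makes rotary embeddings useful: the dot product of two rotated vectors depends only on the difference between their positions.

```python
import math


def rope_1d(vec, pos, base=10000.0, interpolation_scale=1.0):
    """Rotate one feature vector by its 1D position (rotate-half convention)."""
    dim = len(vec)
    half = dim // 2
    # Same schedule as get_cos_sin: 1 / base ** (arange(0, D, 2) / D)
    inv_freq = [1.0 / (base ** (2 * i / dim)) for i in range(half)]
    # Frequencies are concatenated with themselves, as in torch.cat((freqs, freqs)).
    angles = [(pos / interpolation_scale) * f for f in inv_freq] * 2
    # rotate_half: (-x2, x1)
    rotated = [-v for v in vec[half:]] + list(vec[:half])
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(vec, rotated, angles)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.75, -0.5, 1.0]

# Both pairs are two positions apart, so the attention scores should match.
score_a = dot(rope_1d(q, 3), rope_1d(k, 1))
score_b = dot(rope_1d(q, 7), rope_1d(k, 5))
print(abs(score_a - score_b) < 1e-9)
```

`RoPE3D.forward` applies this independently to the t, y and x thirds of each head, each with its own position index and interpolation scale.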
================================================
FILE: opensora/models/diffusion/opensora_v1_3/__init__.py
================================================
================================================
FILE: opensora/models/diffusion/opensora_v1_3/modeling_inpaint.py
================================================
import os
import numpy as np
from torch import nn
import torch
from einops import rearrange, repeat
from typing import Any, Dict, Optional, Tuple
from diffusers.configuration_utils import register_to_config
from opensora.models.diffusion.common import PatchEmbed2D
from opensora.utils.utils import to_2tuple
from opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3 as OpenSoraT2V
import glob
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
class OpenSoraInpaint_v1_3(OpenSoraT2V):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = True,
sample_size_h: Optional[int] = None,
sample_size_w: Optional[int] = None,
sample_size_t: Optional[int] = None,
patch_size: Optional[int] = None,
patch_size_t: Optional[int] = None,
activation_fn: str = "geglu",
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = None,
interpolation_scale_h: float = 1.0,
interpolation_scale_w: float = 1.0,
interpolation_scale_t: float = 1.0,
sparse1d: bool = False,
sparse_n: int = 2,
# inpaint
vae_scale_factor_t: int = 4,
):
super().__init__(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size_h=sample_size_h,
sample_size_w=sample_size_w,
sample_size_t=sample_size_t,
patch_size=patch_size,
patch_size_t=patch_size_t,
activation_fn=activation_fn,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
caption_channels=caption_channels,
interpolation_scale_h=interpolation_scale_h,
interpolation_scale_w=interpolation_scale_w,
interpolation_scale_t=interpolation_scale_t,
sparse1d=sparse1d,
sparse_n=sparse_n,
)
self.vae_scale_factor_t = vae_scale_factor_t
# init masked_pixel_values and mask conv_in
self._init_patched_inputs_for_inpainting()
def _init_patched_inputs_for_inpainting(self):
self.config.sample_size = to_2tuple(self.config.sample_size)
self.pos_embed_masked_hidden_states = nn.ModuleList(
[
PatchEmbed2D(
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.config.hidden_size,
),
zero_module(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)),
]
)
self.pos_embed_mask = nn.ModuleList(
[
PatchEmbed2D(
patch_size=self.config.patch_size,
in_channels=self.vae_scale_factor_t,
embed_dim=self.config.hidden_size,
),
zero_module(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)),
]
)
def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame):
# inpaint
assert hidden_states.shape[1] == 2 * self.config.in_channels + self.vae_scale_factor_t
in_channels = self.config.in_channels
input_hidden_states, input_masked_hidden_states, input_mask = hidden_states[:, :in_channels], hidden_states[:, in_channels: 2 * in_channels], hidden_states[:, 2 * in_channels:]
input_hidden_states = self.pos_embed(input_hidden_states.to(self.dtype))
input_masked_hidden_states = self.pos_embed_masked_hidden_states[0](input_masked_hidden_states.to(self.dtype))
input_masked_hidden_states = self.pos_embed_masked_hidden_states[1](input_masked_hidden_states)
input_mask = self.pos_embed_mask[0](input_mask.to(self.dtype))
input_mask = self.pos_embed_mask[1](input_mask)
hidden_states = input_hidden_states + input_masked_hidden_states + input_mask
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
) # b 6d, b d
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d
assert encoder_hidden_states.shape[1] == 1
encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d')
return hidden_states, encoder_hidden_states, timestep, embedded_timestep
def OpenSoraInpaint_v1_3_2B_122(**kwargs):
return OpenSoraInpaint_v1_3(
num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2,
caption_channels=4096, cross_attention_dim=2304, activation_fn="gelu-approximate", **kwargs
)
OpenSoraInpaint_v1_3_models = {
"OpenSoraInpaint_v1_3-2B/122": OpenSoraInpaint_v1_3_2B_122, # 2.7B
}
OpenSoraInpaint_v1_3_models_class = {
"OpenSoraInpaint_v1_3-2B/122": OpenSoraInpaint_v1_3,
"OpenSoraInpaint_v1_3": OpenSoraInpaint_v1_3,
}
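The channel layout that `_operate_on_patched_inputs` expects can be written out as index ranges. This is a minimal sketch under stated assumptions: `split_inpaint_channels` is a hypothetical helper, and `in_channels=8` is an example value only (the real value comes from the model config). The ordering (noisy latent, masked latent, mask) and the `2 * in_channels + vae_scale_factor_t` check follow the slicing in the method.

```python
def split_inpaint_channels(num_channels, in_channels, vae_scale_factor_t=4):
    """Return [start, end) channel ranges for the three concatenated inpainting inputs."""
    expected = 2 * in_channels + vae_scale_factor_t
    if num_channels != expected:
        raise ValueError(f"expected {expected} channels, got {num_channels}")
    return {
        "hidden_states": (0, in_channels),
        "masked_hidden_states": (in_channels, 2 * in_channels),
        "mask": (2 * in_channels, num_channels),
    }


# Example: 8 latent channels plus a 4-channel mask gives 20 input channels.
layout = split_inpaint_channels(20, in_channels=8)
print(layout)
```

Each range is embedded by its own `PatchEmbed2D`, and the masked-latent and mask branches pass through zero-initialized linear layers, so at initialization they contribute nothing to the summed hidden states.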
================================================
FILE: opensora/models/diffusion/opensora_v1_3/modeling_opensora.py
================================================
import os
import numpy as np
from torch import nn
import torch
from einops import rearrange, repeat
from typing import Any, Dict, Optional, Tuple
from torch.nn import functional as F
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import is_torch_version, deprecate
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.models.embeddings import PixArtAlphaTextProjection
from opensora.models.diffusion.opensora_v1_3.modules import BasicTransformerBlock, Attention
from opensora.models.diffusion.common import PatchEmbed2D
from opensora.utils.utils import to_2tuple
try:
import torch_npu
from opensora.npu_config import npu_config
from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info
except Exception:  # fall back to the GPU code path without swallowing KeyboardInterrupt/SystemExit
torch_npu = None
npu_config = None
from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info
class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = True,
sample_size_h: Optional[int] = None,
sample_size_w: Optional[int] = None,
sample_size_t: Optional[int] = None,
patch_size: Optional[int] = None,
patch_size_t: Optional[int] = None,
activation_fn: str = "geglu",
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = None,
interpolation_scale_h: float = 1.0,
interpolation_scale_w: float = 1.0,
interpolation_scale_t: float = 1.0,
sparse1d: bool = False,
sparse_n: int = 2,
):
super().__init__()
# Set some common variables used across the board.
self.out_channels = in_channels if out_channels is None else out_channels
self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
self._init_patched_inputs()
def _init_patched_inputs(self):
self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w)
interpolation_scale_thw = (
self.config.interpolation_scale_t,
self.config.interpolation_scale_h,
self.config.interpolation_scale_w
)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.config.caption_channels, hidden_size=self.config.hidden_size
)
self.pos_embed = PatchEmbed2D(
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.config.hidden_size,
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.config.hidden_size,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
cross_attention_dim=self.config.cross_attention_dim,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
only_cross_attention=self.config.only_cross_attention,
double_self_attention=self.config.double_self_attention,
upcast_attention=self.config.upcast_attention,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
interpolation_scale_thw=interpolation_scale_thw,
sparse1d=self.config.sparse1d if i > 1 and i < 30 else False,
sparse_n=self.config.sparse_n,
sparse_group=i % 2 == 1,
)
for i in range(self.config.num_layers)
]
)
self.norm_out = nn.LayerNorm(self.config.hidden_size, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.config.hidden_size) / self.config.hidden_size**0.5)
self.proj_out = nn.Linear(
self.config.hidden_size, self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels
)
self.adaln_single = AdaLayerNormSingle(self.config.hidden_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
batch_size, c, frame, h, w = hidden_states.shape
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we can tell by counting dims; if ndim == 4: it's a keep/discard mask rather than a bias.
# expects mask of shape:
# [batch, frame, height, width]
# after patch-wise pooling and flattening it becomes:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 4:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
# b, frame, h, w -> a video
# b, 1, h, w -> only images
attention_mask = attention_mask.to(self.dtype)
attention_mask = attention_mask.unsqueeze(1) # b 1 t h w
attention_mask = F.max_pool3d(
attention_mask,
kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size),
stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size)
)
attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)')
attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
# b, 1, l
encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
# 1. Input
frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t # patchify
height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size
hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
hidden_states, encoder_hidden_states, timestep, batch_size, frame
)
# To
# x (t*h*w b d) or (t//sp*h*w b d)
# cond_1 (l b d) or (l//sp b d)
hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous()
encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous()
timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous()
sparse_mask = {}
if npu_config is None:
if get_sequence_parallel_state():
head_num = self.config.num_attention_heads // nccl_info.world_size
else:
head_num = self.config.num_attention_heads
else:
head_num = None
for sparse_n in [1, 4]:
sparse_mask[sparse_n] = Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num)
# 2. Blocks
for i, block in enumerate(self.transformer_blocks):
if i > 1 and i < 30:
attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][block.attn1.processor.sparse_group]
else:
attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group]
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
frame,
height,
width,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
frame=frame,
height=height,
width=width,
)
# To (b, t*h*w, h) or (b, t//sp*h*w, h)
hidden_states = rearrange(hidden_states, 's b h -> b s h', b=batch_size).contiguous()
# 3. Output
output = self._get_output_for_patched_inputs(
hidden_states=hidden_states,
timestep=timestep,
embedded_timestep=embedded_timestep,
num_frames=frame,
height=height,
width=width,
) # b c t h w
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame):
hidden_states = self.pos_embed(hidden_states.to(self.dtype))
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
) # b 6d, b d
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d
assert encoder_hidden_states.shape[1] == 1
encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d')
return hidden_states, encoder_hidden_states, timestep, embedded_timestep
def _get_output_for_patched_inputs(
self, hidden_states, timestep, embedded_timestep, num_frames, height, width
):
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(-1, num_frames, height, width, self.config.patch_size_t, self.config.patch_size, self.config.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels,
num_frames * self.config.patch_size_t, height * self.config.patch_size, width * self.config.patch_size)
)
return output
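# Illustrative sketch (not part of the original module): the shape bookkeeping done by
# forward() and _get_output_for_patched_inputs() above, restated in plain Python without
# torch. The helper names patched_frame_count, keep_mask_to_bias and unpatchify_coord are
# hypothetical and only mirror the arithmetic; they are not functions of this repository.

```python
def patched_frame_count(frame: int, patch_size_t: int) -> int:
    """Frame count after temporal patching, following the rule used in forward().

    Odd frame counts are treated as "1 leading frame + the rest", even counts are
    divided directly.
    """
    if frame % 2 == 1:
        return (frame - 1) // patch_size_t + 1
    return frame // patch_size_t


def keep_mask_to_bias(mask, masked_value=-10000.0):
    """Turn a keep/discard mask (1 = keep, 0 = discard) into an additive attention bias."""
    return [0.0 if m else masked_value for m in mask]


def unpatchify_coord(t, h, w, o, p, q, patch_size_t, patch_size):
    """Output voxel (frame, row, col) for token (t, h, w) and in-patch offset (o, p, q).

    This is what the einsum "nthwopqc->nctohpwq" followed by the final reshape amounts to:
    each in-patch offset is placed next to its token index and the pair is merged, with
    the token index as the slow axis.
    """
    return (t * patch_size_t + o, h * patch_size + p, w * patch_size + q)


if __name__ == "__main__":
    print(patched_frame_count(9, 1), patched_frame_count(9, 4), patched_frame_count(8, 4))
    print(keep_mask_to_bias([1, 0, 1]))
    print(unpatchify_coord(1, 2, 3, 0, 1, 1, patch_size_t=1, patch_size=2))
```

# With patch_size_t=1 and patch_size=2 (the values used by OpenSoraT2V_v1_3_2B_122 below),
# every token covers one latent frame and a 2x2 spatial patch.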
def OpenSoraT2V_v1_3_2B_122(**kwargs):
kwargs.pop('skip_connection', None)
return OpenSoraT2V_v1_3(
num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2,
caption_channels=4096, cross_attention_dim=2304, activation_fn="gelu-approximate", **kwargs
)
OpenSora_v1_3_models = {
"OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3_2B_122, # 2.7B
}
OpenSora_v1_3_models_class = {
"OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3,
"OpenSoraT2V_v1_3": OpenSoraT2V_v1_3,
}
if __name__ == '__main__':
from opensora.models.causalvideovae import ae_stride_config, ae_channel_config
from opensora.models.causalvideovae import ae_norm, ae_denorm
from opensora.models import CausalVAEModelWrapper
args = type('args', (),
{
'ae': 'WFVAEModel_D8_4x8x8',
'model_max_length': 300,
'max_height': 176,
'max_width': 176,
'num_frames': 33,
'compress_kv_factor': 1,
'interpolation_scale_t': 1,
'interpolation_scale_h': 1,
'interpolation_scale_w': 1,
"sparse1d": True,
"sparse_n": 4,
"rank": 64,
}
)
b = 2
c = 8
cond_c = 4096
num_timesteps = 1000
ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae]
latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w)
num_frames = (args.num_frames - 1) // ae_stride_t + 1
device = torch.device('cuda:0')
model = OpenSoraT2V_v1_3_2B_122(
in_channels=c,
out_channels=c,
sample_size_h=latent_size[0],
sample_size_w=latent_size[1],
sample_size_t=num_frames,
activation_fn="gelu-approximate",
attention_bias=True,
double_self_attention=False,
norm_elementwise_affine=False,
norm_eps=1e-06,
only_cross_attention=False,
upcast_attention=False,
interpolation_scale_t=args.interpolation_scale_t,
interpolation_scale_h=args.interpolation_scale_h,
interpolation_scale_w=args.interpolation_scale_w,
sparse1d=args.sparse1d,
sparse_n=args.sparse_n
)
try:
path = "/storage/ongoing/new/7.19anyres/Open-Sora-Plan/bs32x8x1_anyx93x640x640_fps16_lr1e-5_snr5_ema9999_sparse1d4_dit_l_mt5xxl_vpred_zerosnr/checkpoint-43000/model_ema/diffusion_pytorch_model.safetensors"
# ckpt = torch.load(path, map_location="cpu")
from safetensors.torch import load_file as safe_load
ckpt = safe_load(path, device="cpu")
msg = model.load_state_dict(ckpt, strict=True)
print(msg)
except Exception as e:
print(e)
print(model)
print(f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B')
# import sys;sys.exit()
model = model.to(device)
x = torch.randn(b, c, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w).to(device)
cond = torch.randn(b, 1, args.model_max_length, cond_c).to(device)
attn_mask = torch.randint(0, 2, (b, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w)).to(device) # B L or B 1+num_images L
cond_mask = torch.randint(0, 2, (b, 1, args.model_max_length)).to(device) # B L or B 1+num_images L
timestep = torch.randint(0, 1000, (b,), device=device)
model_kwargs = dict(
hidden_states=x, encoder_hidden_states=cond, attention_mask=attn_mask,
encoder_attention_mask=cond_mask, timestep=timestep
)
with torch.no_grad():
output = model(**model_kwargs)
print(output[0].shape)
================================================
FILE: opensora/models/diffusion/opensora_v1_3/modules.py
================================================
import torch
from einops import rearrange, repeat
from typing import Any, Dict, Optional, Tuple
import torch.nn.functional as F
from torch import nn
from diffusers.utils import logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention as Attention_
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
try:
import torch_npu
from opensora.npu_config import npu_config, set_run_dtype
from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info as xccl_info
from opensora.acceleration.communications import all_to_all_SBH
except:
torch_npu = None
npu_config = None
set_run_dtype = None
from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info as xccl_info
from opensora.utils.communications import all_to_all_SBH
from ..common import RoPE3D, PositionGetter3D
logger = logging.get_logger(__name__)
class Attention(Attention_):
def __init__(
self, interpolation_scale_thw, sparse1d, sparse_n,
sparse_group, is_cross_attn, **kwargs
):
processor = OpenSoraAttnProcessor2_0(
interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n,
sparse_group=sparse_group, is_cross_attn=is_cross_attn
)
super().__init__(processor=processor, **kwargs)
@staticmethod
def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num):
attention_mask = attention_mask.unsqueeze(1)
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
l = attention_mask.shape[-1]
if l % (sparse_n * sparse_n) == 0:
pad_len = 0
else:
pad_len = sparse_n * sparse_n - l % (sparse_n * sparse_n)
attention_mask_sparse = F.pad(attention_mask, (0, pad_len, 0, 0), value=-9980.0)
attention_mask_sparse_1d = rearrange(
attention_mask_sparse,
'b 1 1 (g k) -> (k b) 1 1 g',
k=sparse_n
)
attention_mask_sparse_1d_group = rearrange(
attention_mask_sparse,
'b 1 1 (n m k) -> (m b) 1 1 (n k)',
m=sparse_n,
k=sparse_n
)
encoder_attention_mask_sparse = encoder_attention_mask.repeat(sparse_n, 1, 1, 1)
if npu_config is not None:
attention_mask_sparse_1d = npu_config.get_attention_mask(
attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1]
)
attention_mask_sparse_1d_group = npu_config.get_attention_mask(
attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1]
)
encoder_attention_mask_sparse_1d = npu_config.get_attention_mask(
encoder_attention_mask_sparse, attention_mask_sparse_1d.shape[-1]
)
encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d
else:
attention_mask_sparse_1d = attention_mask_sparse_1d.repeat_interleave(head_num, dim=1)
attention_mask_sparse_1d_group = attention_mask_sparse_1d_group.repeat_interleave(head_num, dim=1)
encoder_attention_mask_sparse_1d = encoder_attention_mask_sparse.repeat_interleave(head_num, dim=1)
encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d
return {
False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d),
True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group)
}
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
r"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`):
The attention mask to prepare.
target_length (`int`):
The target length of the attention mask. This is the length of the attention mask after padding.
batch_size (`int`):
The batch size, which is used to repeat the attention mask.
out_dim (`int`, *optional*, defaults to `3`):
The output dimension of the attention mask. Can be either `3` or `4`.
Returns:
`torch.Tensor`: The prepared attention mask.
"""
head_size = self.heads
if get_sequence_parallel_state():
head_size = head_size // xccl_info.world_size # e.g, 24 // 8
if attention_mask is None: # b 1 t*h*w in sa, b 1 l in ca
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length != target_length:
print(f'attention_mask.shape, {attention_mask.shape}, current_length, {current_length}, target_length, {target_length}')
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
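# Illustrative sketch (not part of the original module): prepare_sparse_mask above pads the
# token axis to a multiple of sparse_n * sparse_n and then splits it in two different ways,
# selected later by sparse_group. The sketch below reproduces both splits on plain lists of
# token indices instead of einops patterns, so it is easy to see which tokens attend to each
# other. The function names are hypothetical and exist only for this illustration.

```python
def sparse_pad_len(length: int, sparse_n: int) -> int:
    """Padding needed so that length becomes a multiple of sparse_n * sparse_n."""
    block = sparse_n * sparse_n
    return (block - length % block) % block


def strided_groups(length: int, sparse_n: int):
    """Split used when sparse_group is False: '(g k) -> k g'.

    Sub-sequence k holds every sparse_n-th token, starting at offset k.
    """
    return [[i for i in range(length) if i % sparse_n == k] for k in range(sparse_n)]


def chunked_groups(length: int, sparse_n: int):
    """Split used when sparse_group is True: '(n m k) -> m (n k)'.

    Tokens are cut into chunks of sparse_n consecutive tokens, and sub-sequence m
    holds every sparse_n-th chunk, starting at chunk m.
    """
    return [
        [i for i in range(length) if (i // sparse_n) % sparse_n == m]
        for m in range(sparse_n)
    ]


if __name__ == "__main__":
    n_tokens, sparse_n = 10, 2
    padded = n_tokens + sparse_pad_len(n_tokens, sparse_n)
    print(padded)
    print(strided_groups(8, 2))
    print(chunked_groups(8, 2))
```

# Each sub-sequence is sparse_n times shorter than the padded input, which is why the
# masks returned above carry sparse_n times more entries along the batch axis.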
class OpenSoraAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self, interpolation_scale_thw=(1, 1, 1),
sparse1d=False, sparse_n=2, sparse_group=False, is_cross_attn=True):
self.sparse1d = sparse1d
self.sparse_n = sparse_n
self.sparse_group = sparse_group
self.is_cross_attn = is_cross_attn
self.interpolation_scale_thw = interpolation_scale_thw
self._init_rope(interpolation_scale_thw)
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def _init_rope(self, interpolation_scale_thw):
self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
self.position_getter = PositionGetter3D()
def _sparse_1d(self, x, frame, height, width):
"""
require the shape of (ntokens x batch_size x dim)
"""
l = x.shape[0]
assert l == frame*height*width
pad_len = 0
if l % (self.sparse_n * self.sparse_n) != 0:
pad_len = self.sparse_n * self.sparse_n - l % (self.sparse_n * self.sparse_n)
if pad_len != 0:
x = F.pad(x, (0, 0, 0, 0, 0, pad_len))
if not self.sparse_group:
x = rearrange(x, '(g k) b d -> g (k b) d', k=self.sparse_n)
else:
x = rearrange(x, '(n m k) b d -> (n k) (m b) d', m=self.sparse_n, k=self.sparse_n)
return x, pad_len
def _reverse_sparse_1d(self, x, frame, height, width, pad_len):
"""
require the shape of (ntokens x batch_size x dim)
"""
assert x.shape[0] == (frame*height*width+pad_len) // self.sparse_n
if not self.sparse_group:
x = rearrange(x, 'g (k b) d -> (g k) b d', k=self.sparse_n)
else:
x = rearrange(x, '(n k) (m b) d -> (n m k) b d', m=self.sparse_n, k=self.sparse_n)
x = x[:frame*height*width, :, :]
return x
def _sparse_1d_kv(self, x):
"""
require the shape of (ntokens x batch_size x dim)
"""
x = repeat(x, 's b d -> s (k b) d', k=self.sparse_n)
return x
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
frame: int = 8,
height: int = 16,
width: int = 16,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
sequence_length, batch_size, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# if attention_mask is not None:
# if npu_config is None:
# # scaled_dot_product_attention expects attention_mask shape to be
# # (batch, heads, source_length, target_length)
# if get_sequence_parallel_state():
# # sequence_length has been split, so we need sequence_length * nccl_info.world_size
# # (sp*b 1 s), where s has not been split
# # (sp*b 1 s) -prepare-> (sp*b*head 1 s) -> (sp*b head 1 s), where head has been split (e.g, 24 // 8)
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length * xccl_info.world_size, batch_size)
# attention_mask = attention_mask.view(batch_size, attn.heads // xccl_info.world_size, -1, attention_mask.shape[-1])
# else:
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
FA_head_num = attn.heads
total_frame = frame
if get_sequence_parallel_state():
sp_size = xccl_info.world_size
FA_head_num = attn.heads // sp_size
total_frame = frame * sp_size
# apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d]
query = all_to_all_SBH(query.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)
key = all_to_all_SBH(key.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)
value = all_to_all_SBH(value.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)
query = query.view(-1, batch_size, FA_head_num, head_dim)
key = key.view(-1, batch_size, FA_head_num, head_dim)
if not self.is_cross_attn:
# require the shape of (ntokens x batch_size x nheads x dim)
pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width, device=query.device)
query = self.rope(query, pos_thw)
key = self.rope(key, pos_thw)
# query = rearrange(query, 's b h d -> b h s d')
# key = rearrange(key, 's b h d -> b h s d')
# dtype = query.dtype
# query = self.rope(query.to(torch.float16), pos_thw)
# key = self.rope(key.to(torch.float16), pos_thw)
# query = rearrange(query, 'b h s d -> s b h d').to(dtype)
# key = rearrange(key, 'b h s d -> s b h d').to(dtype)
query = query.view(-1, batch_size, FA_head_num * head_dim)
key = key.view(-1, batch_size, FA_head_num * head_dim)
value = value.view(-1, batch_size, FA_head_num * head_dim)
# print(f'q {query.shape}, k {key.shape}, v {value.shape}')
if self.sparse1d:
query, pad_len = self._sparse_1d(query, total_frame, height, width)
if self.is_cross_attn:
key = self._sparse_1d_kv(key)
value = self._sparse_1d_kv(value)
else:
key, pad_len = self._sparse_1d(key, total_frame, height, width)
value, pad_len = self._sparse_1d(value, total_frame, height, width)
# print(f'after sparse q {query.shape}, k {key.shape}, v {value.shape}')
if npu_config is not None:
hidden_states = npu_config.run_attention(query, key, value, attention_mask, "SBH", head_dim, FA_head_num)
else:
query = rearrange(query, 's b (h d) -> b h s d', h=FA_head_num)
key = rearrange(key, 's b (h d) -> b h s d', h=FA_head_num)
value = rearrange(value, 's b (h d) -> b h s d', h=FA_head_num)
# 0, -10000 ->(bool) False, True ->(any) True ->(not) False
# 0, 0 ->(bool) False, False ->(any) False ->(not) True
# if attention_mask is None or not torch.any(attention_mask.bool()): # 0 mean visible
# attention_mask = None
# the output of sdp = (batch, num_heads, seq_len, head_dim)
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = rearrange(hidden_states, 'b h s d -> s b (h d)', h=FA_head_num)
if self.sparse1d:
hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len)
# [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
if get_sequence_parallel_state():
hidden_states = all_to_all_SBH(hidden_states.reshape(-1, FA_head_num, head_dim), scatter_dim=0, gather_dim=1)
hidden_states = hidden_states.view(-1, batch_size, inner_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
# if attn.residual_connection:
# print('attn.residual_connection')
# hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = False,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
interpolation_scale_thw: Tuple[int] = (1, 1, 1),
sparse1d: bool = False,
sparse_n: int = 2,
sparse_group: bool = False,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
interpolation_scale_thw=interpolation_scale_thw,
sparse1d=sparse1d,
sparse_n=sparse_n,
sparse_group=sparse_group,
is_cross_attn=False,
)
# 2. Cross-Attn
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
interpolation_scale_thw=interpolation_scale_thw,
sparse1d=sparse1d,
sparse_n=sparse_n,
sparse_group=sparse_group,
is_cross_attn=True,
) # is self-attn if encoder_hidden_states is none
# 3. Feed-forward
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Scale-shift.
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
frame: int = None,
height: int = None,
width: int = None,
) -> torch.FloatTensor:
# 0. Self-Attention
batch_size = hidden_states.shape[1]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1)
).chunk(6, dim=0)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask, frame=frame, height=height, width=width,
)
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
norm_hidden_states = hidden_states
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, frame=frame, height=height, width=width,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
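# Illustrative sketch (not part of the original module): the forward pass above applies the
# same pattern to the self-attention and feed-forward branches, namely
# "modulate the normalized input with (shift, scale), run the branch, scale the result by a
# gate, add it back to the residual stream". The two helpers below restate that arithmetic on
# plain Python lists; modulate and gated_residual are hypothetical names, not repository code.

```python
def modulate(x, shift, scale):
    """Apply the adaLN-style modulation x * (1 + scale) + shift element-wise."""
    return [v * (1 + scale) + shift for v in x]


def gated_residual(x, branch_out, gate):
    """Add a gated branch output back onto the residual stream."""
    return [v + gate * b for v, b in zip(x, branch_out)]


if __name__ == "__main__":
    x = [1.0, 2.0]
    # scale = 0 and shift = 0 leave the normalized input unchanged.
    print(modulate(x, shift=0.0, scale=0.0))
    # gate = 0 switches the branch off, so the block passes x through unchanged.
    print(gated_residual(x, branch_out=[10.0, 10.0], gate=0.0))
```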
================================================
FILE: opensora/models/frame_interpolation/cfgs/AMT-G.yaml
================================================
seed: 2023
network:
name: networks.AMT-G.Model
params:
corr_radius: 3
corr_lvls: 4
num_flows: 5
================================================
FILE: opensora/models/frame_interpolation/interpolation.py
================================================
# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py
from json import load
import os
import cv2
import sys
import glob
import torch
import argparse
import numpy as np
import os.path as osp
from warnings import warn
from omegaconf import OmegaConf
from torchvision.utils import make_grid
sys.path.append('.')
from utils.utils import (
read, write,
img2tensor, tensor2img,
check_dim_and_resize
)
from utils.build_utils import build_from_cfg
from utils.utils import InputPadder
AMT_G = {
'name': 'networks.AMT-G.Model',
'params':{
'corr_radius': 3,
'corr_lvls': 4,
'num_flows': 5,
}
}
def init(device="cuda"):
'''
initialize the device and the anchor resolution.
'''
if device == 'cuda':
anchor_resolution = 1024 * 512
anchor_memory = 1500 * 1024**2
anchor_memory_bias = 2500 * 1024**2
vram_avail = torch.cuda.get_device_properties(device).total_memory
print("VRAM available: {:.1f} MB".format(vram_avail / 1024 ** 2))
else:
# Do not resize in cpu mode
anchor_resolution = 8192*8192
anchor_memory = 1
anchor_memory_bias = 0
vram_avail = 1
return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail
def get_input_video_from_path(input_path, device="cuda"):
'''
Get the input video from the input_path.
params:
input_path: str, the path of the input video.
device: str, the device to run the model.
returns:
inputs: list, the list of the input frames.
scale: float, the scale of the input frames.
padder: InputPadder, the padder to pad the input frames.
'''
anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device)
if osp.splitext(input_path)[-1] in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv',
'.webm', '.MP4', '.AVI', '.MOV', '.MKV', '.FLV',
'.WMV', '.WEBM']:
vcap = cv2.VideoCapture(input_path)
inputs = []
w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
scale = 1 if scale > 1 else scale
scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
if scale < 1:
print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
padding = int(16 / scale)
padder = InputPadder((h, w), padding)
while True:
ret, frame = vcap.read()
if ret is False:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_t = img2tensor(frame).to(device)
frame_t = padder.pad(frame_t)
inputs.append(frame_t)
print(f'Loading the [video] from {input_path}, the number of frames [{len(inputs)}]')
else:
raise TypeError("Input should be a video.")
return inputs, scale, padder
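# Illustrative sketch (not part of the original script): the scale computed in
# get_input_video_from_path above depends only on the frame size and the available memory.
# The function below restates that formula with the CUDA anchor values from init() as
# defaults, using math instead of numpy. The name amt_scale is hypothetical. Note one
# difference: math.sqrt raises ValueError if vram_avail is smaller than anchor_memory_bias,
# whereas np.sqrt would return nan.

```python
import math


def amt_scale(h, w, vram_avail,
              anchor_resolution=1024 * 512,
              anchor_memory=1500 * 1024**2,
              anchor_memory_bias=2500 * 1024**2):
    """Down-scaling factor for an h x w frame given vram_avail bytes of device memory."""
    scale = anchor_resolution / (h * w) * math.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = min(scale, 1.0)  # never upscale
    # Snap to a value of the form 16 / k with integer k >= 16.
    return 1 / math.floor(1 / math.sqrt(scale) * 16) * 16


if __name__ == "__main__":
    eight_gib = 8192 * 1024**2
    print(amt_scale(1024, 1024, eight_gib))
    print(amt_scale(2160, 3840, eight_gib))
```

# The padding passed to InputPadder is then int(16 / scale), so a smaller scale means
# frames are padded to a larger multiple.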
def load_model(ckpt_path, device="cuda"):
'''
load the frame interpolation model.
'''
network_cfg = AMT_G
network_name = network_cfg['name']
print(f'Loading [{network_name}] from [{ckpt_path}]...')
model = build_from_cfg(network_cfg)
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
model.eval()
return model
def interpolater(model, inputs, scale, padder, iters=1):
'''
interpolating with the interpolation model.
params:
model: nn.Module, the frame interpolation model.
inputs: list, the list of the input frames.
scale: float, the scale of the input frames.
padder: InputPadder, the padder used to remove the padding from the output frames.
iters: int, the number of interpolation iterations. With m input frames, the output has 2 ** iters * (m - 1) + 1 frames.
returns:
outputs: list, the list of the output frames.
'''
print(f'Start frame interpolation:')
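# Use the model's own device so this function does not depend on a global `device`.
device = next(model.parameters()).device
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)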
for i in range(iters):
print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
outputs = [inputs[0]]
for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
in_0 = in_0.to(device)
in_1 = in_1.to(device)
with torch.no_grad():
imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
outputs += [imgt_pred.cpu(), in_1.cpu()]
inputs = outputs
outputs = padder.unpad(*outputs)
return outputs
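# Illustrative sketch (not part of the original script): each pass of the loop in
# interpolater inserts one predicted frame between every pair of neighbours, so m frames
# become 2 * m - 1. The sketch below checks that repeating this iters times matches the
# closed form 2 ** iters * (m - 1) + 1 quoted in the docstring. Both function names are
# hypothetical.

```python
def frames_after(m: int, iters: int) -> int:
    """Frame count after applying the m -> 2 * m - 1 step iters times."""
    for _ in range(iters):
        m = 2 * m - 1
    return m


def frames_closed_form(m: int, iters: int) -> int:
    """Closed-form frame count for m input frames and iters interpolation passes."""
    return 2 ** iters * (m - 1) + 1


if __name__ == "__main__":
    for iters in range(4):
        print(iters, frames_after(5, iters), frames_closed_form(5, iters))
```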
def write(outputs, input_path, output_path, frame_rate=30):
'''
write results to the output_path.
'''
os.makedirs(output_path, exist_ok=True)
size = outputs[0].shape[2:][::-1]
_, file_name_with_extension = os.path.split(input_path)
file_name, _ = os.path.splitext(file_name_with_extension)
save_video_path = f'{output_path}/output_{file_name}.mp4'
writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*"mp4v"),
frame_rate, size)
for i, imgt_pred in enumerate(outputs):
imgt_pred = tensor2img(imgt_pred)
imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
writer.write(imgt_pred)
print(f"Demo video is saved to [{save_video_path}]")
writer.release()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', type=str, default='amt-g.pth', help="The pretrained model.")
parser.add_argument('--niters', type=int, default=1, help="Number of interpolation iterations. Each iteration turns m frames into 2 * m - 1 frames.")
parser.add_argument('--input', default="test.mp4", help="Input video.")
parser.add_argument('--output_path', type=str, default='results', help="Output path.")
parser.add_argument('--frame_rate', type=int, default=30, help="Frame rate of the output video.")
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt_path = args.ckpt
input_path = args.input
output_path = args.output_path
iters = int(args.niters)
frame_rate = int(args.frame_rate)
inputs, scale, padder = get_input_video_from_path(input_path, device)
model = load_model(ckpt_path, device)
outputs = interpolater(model, inputs, scale, padder, iters)
write(outputs, input_path, output_path, frame_rate)
================================================
FILE: opensora/models/frame_interpolation/networks/AMT-G.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from networks.blocks.raft import (
coords_grid,
BasicUpdateBlock, BidirCorrBlock
)
from networks.blocks.feat_enc import (
LargeEncoder
)
from networks.blocks.ifrnet import (
resize,
Encoder,
InitDecoder,
IntermediateDecoder
)
from networks.blocks.multi_flow import (
multi_flow_combine,
MultiFlowDecoder
)
class Model(nn.Module):
def __init__(self,
corr_radius=3,
corr_lvls=4,
num_flows=5,
channels=[84, 96, 112, 128],
skip_channels=84):
super(Model, self).__init__()
self.radius = corr_radius
self.corr_levels = corr_lvls
self.num_flows = num_flows
self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.)
self.encoder = Encoder(channels, large=True)
self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)
self.update4 = self._get_updateblock(112, None)
self.update3_low = self._get_updateblock(96, 2.0)
self.update2_low = self._get_updateblock(84, 4.0)
self.update3_high = self._get_updateblock(96, None)
self.update2_high = self._get_updateblock(84, None)
self.comb_block = nn.Sequential(
nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
nn.PReLU(6*self.num_flows),
nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
)
def _get_updateblock(self, cdim, scale_factor=None):
return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64,
corr_dim=256, corr_dim2=192, fc_dim=188,
scale_factor=scale_factor, corr_levels=self.corr_levels,
radius=self.radius)
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption
t1_scale = 1. / embt
t0_scale = 1. / (1. - embt)
if downsample != 1:
inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv)
flow1 = inv * resize(flow1, scale_factor=inv)
corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
corr = torch.cat([corr0, corr1], dim=1)
flow = torch.cat([flow0, flow1], dim=1)
return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
img0 = img0 - mean_
img1 = img1 - mean_
img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
b, _, h, w = img0_.shape
coord = coords_grid(b, h // 8, w // 8, img0.device)
fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
# f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
# f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)
######################################### the 4th decoder #########################################
up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
up_flow0_4, up_flow1_4,
embt, downsample=1)
# residue update with lookup corr
delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
up_flow0_4 = up_flow0_4 + delta_flow0_4
up_flow1_4 = up_flow1_4 + delta_flow1_4
ft_3_ = ft_3_ + delta_ft_3_
######################################### the 3rd decoder #########################################
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_3, up_flow1_3,
embt, downsample=2)
# residue update with lookup corr
delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3)
delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
up_flow0_3 = up_flow0_3 + delta_flow0_3
up_flow1_3 = up_flow1_3 + delta_flow1_3
ft_2_ = ft_2_ + delta_ft_2_
# residue update with lookup corr (hr)
corr_3 = resize(corr_3, scale_factor=2.0)
up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1)
delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3)
ft_2_ += delta_ft_2_
up_flow0_3 += delta_up_flow_3[:, 0:2]
up_flow1_3 += delta_up_flow_3[:, 2:4]
######################################### the 2nd decoder #########################################
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_2, up_flow1_2,
embt, downsample=4)
# residue update with lookup corr
delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2)
delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
up_flow0_2 = up_flow0_2 + delta_flow0_2
up_flow1_2 = up_flow1_2 + delta_flow1_2
ft_1_ = ft_1_ + delta_ft_1_
# residue update with lookup corr (hr)
corr_2 = resize(corr_2, scale_factor=4.0)
up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1)
delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2)
ft_1_ += delta_ft_1_
up_flow0_2 += delta_up_flow_2[:, 0:2]
up_flow1_2 += delta_up_flow_2[:, 2:4]
######################################### the 1st decoder #########################################
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
mask = resize(mask, scale_factor=(1.0/scale_factor))
img_res = resize(img_res, scale_factor=(1.0/scale_factor))
# Merge multiple predictions
imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
mask, img_res, mean_)
imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval:
return { 'imgt_pred': imgt_pred, }
else:
up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
return {
'imgt_pred': imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_],
}
================================================
FILE: opensora/models/frame_interpolation/networks/__init__.py
================================================
================================================
FILE: opensora/models/frame_interpolation/networks/blocks/__init__.py
================================================
================================================
FILE: opensora/models/frame_interpolation/networks/blocks/feat_enc.py
================================================
import torch
import torch.nn as nn
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(72, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class LargeEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(LargeEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(112, stride=2)
self.layer3 = self._make_layer(160, stride=2)
self.layer3_2 = self._make_layer(160, stride=1)
# output convolution
self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer3_2(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
================================================
FILE: opensora/models/frame_interpolation/networks/blocks/ifrnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.flow_utils import warp
def resize(x, scale_factor):
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
nn.PReLU(out_channels)
)
class ResBlock(nn.Module):
def __init__(self, in_channels, side_channels, bias=True):
super(ResBlock, self).__init__()
self.side_channels = side_channels
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(in_channels)
)
self.conv2 = nn.Sequential(
nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(side_channels)
)
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(in_channels)
)
self.conv4 = nn.Sequential(
nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(side_channels)
)
self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
self.prelu = nn.PReLU(in_channels)
def forward(self, x):
out = self.conv1(x)
res_feat = out[:, :-self.side_channels, ...]
side_feat = out[:, -self.side_channels:, :, :]
side_feat = self.conv2(side_feat)
out = self.conv3(torch.cat([res_feat, side_feat], 1))
res_feat = out[:, :-self.side_channels, ...]
side_feat = out[:, -self.side_channels:, :, :]
side_feat = self.conv4(side_feat)
out = self.conv5(torch.cat([res_feat, side_feat], 1))
out = self.prelu(x + out)
return out
class Encoder(nn.Module):
def __init__(self, channels, large=False):
super(Encoder, self).__init__()
self.channels = channels
prev_ch = 3
for idx, ch in enumerate(channels, 1):
k = 7 if large and idx == 1 else 3
p = 3 if k ==7 else 1
self.register_module(f'pyramid{idx}',
nn.Sequential(
convrelu(prev_ch, ch, k, 2, p),
convrelu(ch, ch, 3, 1, 1)
))
prev_ch = ch
def forward(self, in_x):
fs = []
for idx in range(len(self.channels)):
out_x = getattr(self, f'pyramid{idx+1}')(in_x)
fs.append(out_x)
in_x = out_x
return fs
class InitDecoder(nn.Module):
def __init__(self, in_ch, out_ch, skip_ch) -> None:
super().__init__()
self.convblock = nn.Sequential(
convrelu(in_ch*2+1, in_ch*2),
ResBlock(in_ch*2, skip_ch),
nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
)
def forward(self, f0, f1, embt):
h, w = f0.shape[2:]
embt = embt.repeat(1, 1, h, w)
out = self.convblock(torch.cat([f0, f1, embt], 1))
flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
ft_ = out[:, 4:, ...]
return flow0, flow1, ft_
class IntermediateDecoder(nn.Module):
def __init__(self, in_ch, out_ch, skip_ch) -> None:
super().__init__()
self.convblock = nn.Sequential(
convrelu(in_ch*3+4, in_ch*3),
ResBlock(in_ch*3, skip_ch),
nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
)
def forward(self, ft_, f0, f1, flow0_in, flow1_in):
f0_warp = warp(f0, flow0_in)
f1_warp = warp(f1, flow1_in)
f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1)
out = self.convblock(f_in)
flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
ft_ = out[:, 4:, ...]
flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
return flow0, flow1, ft_
================================================
FILE: opensora/models/frame_interpolation/networks/blocks/multi_flow.py
================================================
import torch
import torch.nn as nn
from utils.flow_utils import warp
from networks.blocks.ifrnet import (
convrelu, resize,
ResBlock,
)
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
    A parallel implementation of multiple flow field warping.
    comb_block: An nn.Sequential object.
    img shape: [b, c, h, w]
    flow shape: [b, 2*num_flows, h, w]
    mask (opt):
        If 'mask' is None, the function computes a simple average.
    img_res (opt):
        If 'img_res' is None, the function adds zero instead.
    mean (opt):
        If 'mean' is None, the function adds zero instead.
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the flow dimension into the batch so all flows are warped at once.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    # A constant 0.5 weight gives the simple average described in the docstring.
    mask = mask.reshape(b, num_flows, 1, h, w
                        ).reshape(-1, 1, h, w) if mask is not None else 0.5
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # Average the per-flow predictions and add a learned residual correction.
    imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
    return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch*3+4, in_ch*3),
ResBlock(in_ch*3, skip_ch),
nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
return flow0, flow1, mask, img_res
================================================
FILE: opensora/models/frame_interpolation/networks/blocks/raft.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize(x, scale_factor):
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
def bilinear_sampler(img, coords, mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device),
torch.arange(wd, device=device),
indexing='ij')
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
class SmallUpdateBlock(nn.Module):
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
corr_levels=4, radius=3, scale_factor=None):
super(SmallUpdateBlock, self).__init__()
cor_planes = corr_levels * (2 * radius + 1) **2
self.scale_factor = scale_factor
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)
self.gru = nn.Sequential(
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
)
self.feat_head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, cdim, 3, padding=1),
)
self.flow_head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, 4, 3, padding=1),
)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, net, flow, corr):
net = resize(net, 1 / self.scale_factor
) if self.scale_factor is not None else net
cor = self.lrelu(self.convc1(corr))
flo = self.lrelu(self.convf1(flow))
flo = self.lrelu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
inp = self.lrelu(self.conv(cor_flo))
inp = torch.cat([inp, flow, net], dim=1)
out = self.gru(inp)
delta_net = self.feat_head(out)
delta_flow = self.flow_head(out)
if self.scale_factor is not None:
delta_net = resize(delta_net, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
return delta_net, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
super(BasicUpdateBlock, self).__init__()
cor_planes = corr_levels * (2 * radius + 1) **2
self.scale_factor = scale_factor
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)
self.gru = nn.Sequential(
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
)
self.feat_head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, cdim, 3, padding=1),
)
self.flow_head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, net, flow, corr):
net = resize(net, 1 / self.scale_factor
) if self.scale_factor is not None else net
cor = self.lrelu(self.convc1(corr))
cor = self.lrelu(self.convc2(cor))
flo = self.lrelu(self.convf1(flow))
flo = self.lrelu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
inp = self.lrelu(self.conv(cor_flo))
inp = torch.cat([inp, flow, net], dim=1)
out = self.gru(inp)
delta_net = self.feat_head(out)
delta_flow = self.flow_head(out)
if self.scale_factor is not None:
delta_net = resize(delta_net, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
return delta_net, delta_flow
class BidirCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
self.corr_pyramid_T = []
corr = BidirCorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)
self.corr_pyramid.append(corr)
self.corr_pyramid_T.append(corr_T)
for _ in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
corr_T = F.avg_pool2d(corr_T, 2, stride=2)
self.corr_pyramid.append(corr)
self.corr_pyramid_T.append(corr_T)
def __call__(self, coords0, coords1):
r = self.radius
coords0 = coords0.permute(0, 2, 3, 1)
coords1 = coords1.permute(0, 2, 3, 1)
assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
batch, h1, w1, _ = coords0.shape
out_pyramid = []
out_pyramid_T = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
corr_T = self.corr_pyramid_T[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
coords_lvl_0 = centroid_lvl_0 + delta_lvl
coords_lvl_1 = centroid_lvl_1 + delta_lvl
corr = bilinear_sampler(corr, coords_lvl_0)
corr_T = bilinear_sampler(corr_T, coords_lvl_1)
corr = corr.view(batch, h1, w1, -1)
corr_T = corr_T.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out_pyramid_T.append(corr_T)
out = torch.cat(out_pyramid, dim=-1)
out_T = torch.cat(out_pyramid_T, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
================================================
FILE: opensora/models/frame_interpolation/readme.md
================================================
#### Frame Interpolation
We use [AMT](https://github.com/MCG-NKU/AMT) as our frame interpolation model (thanks to the AMT authors). After sampling, you can run the frame interpolation model to add in-between frames and make your video smoother.
1. Download the pretrained weights from [AMT](https://github.com/MCG-NKU/AMT). We recommend the largest model, AMT-G, for the best performance.
2. Run the frame interpolation script.
```
python opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/folder --frame_rate 30
```
3. The output video is saved in `output_path` as `output_<input file name>.mp4`. Its duration equals `the total number of frames after frame interpolation / the frame rate`.
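The duration rule in step 3 can be checked with a minimal sketch. The helper name `video_duration_seconds` is hypothetical and is not part of this repository.

```python
def video_duration_seconds(total_frames: int, frame_rate: float) -> float:
    """Duration in seconds of a clip with `total_frames` frames at `frame_rate` fps.

    Hypothetical helper for illustration only.
    """
    if frame_rate <= 0:
        raise ValueError("frame_rate must be positive")
    return total_frames / frame_rate


if __name__ == "__main__":
    # 120 frames written at 30 fps play for 4 seconds.
    print(video_duration_seconds(120, 30))
```

Because the duration depends only on the frame count and `--frame_rate`, raising `--niters` while keeping the same frame rate makes the output video longer rather than smoother at the original speed.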
##### Frame Interpolation Specific Settings
* `--ckpt`: Path to the pretrained [AMT](https://github.com/MCG-NKU/AMT) weights. We use AMT-G as our frame interpolation model.
* `--niters`: Number of interpolation iterations. With $m$ input frames, `--niters` $=n$ yields $2^n\times (m-1)+1$ output frames.
* `--input`: Path of the input video.
* `--output_path`: Folder path for the output video.
* `--frame_rate`: Frame rate of the output video.
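To make the frame-count formula concrete, here is a small sketch that compares the closed form with a pass-by-pass count. Both function names are hypothetical. The pass-by-pass version assumes each iteration inserts exactly one new frame between every pair of adjacent frames.

```python
def interpolated_frame_count(num_input_frames: int, niters: int) -> int:
    """Closed form: 2^n * (m - 1) + 1 output frames for m inputs and n iterations."""
    if num_input_frames < 1 or niters < 0:
        raise ValueError("need at least one input frame and niters >= 0")
    return 2 ** niters * (num_input_frames - 1) + 1


def simulated_frame_count(num_input_frames: int, niters: int) -> int:
    """Count frames pass by pass, assuming one new frame per adjacent pair."""
    count = num_input_frames
    for _ in range(niters):
        # m frames have m - 1 adjacent pairs, so one pass yields m + (m - 1) frames.
        count = 2 * count - 1
    return count


if __name__ == "__main__":
    for m, n in [(17, 1), (17, 2), (29, 3)]:
        assert interpolated_frame_count(m, n) == simulated_frame_count(m, n)
        print(m, n, interpolated_frame_count(m, n))
```

For example, 17 input frames with `--niters 2` give 65 output frames.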
================================================
FILE: opensora/models/frame_interpolation/utils/__init__.py
================================================
================================================
FILE: opensora/models/frame_interpolation/utils/build_utils.py
================================================
import importlib


def base_build_fn(module, cls, params):
    return getattr(importlib.import_module(
        module, package=None), cls)(**params)


def build_from_cfg(config):
    module, cls = config['name'].rsplit(".", 1)
    params = config.get('params', {})
    return base_build_fn(module, cls, params)
================================================
FILE: opensora/models/frame_interpolation/utils/dist_utils.py
================================================
import os
import torch
def get_world_size():
"""Find OMPI world size without calling mpi functions
:rtype: int
"""
if os.environ.get('PMI_SIZE') is not None:
return int(os.environ.get('PMI_SIZE') or 1)
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
else:
return torch.cuda.device_count()
def get_global_rank():
"""Find OMPI world rank without calling mpi functions
:rtype: int
"""
if os.environ.get('PMI_RANK') is not None:
return int(os.environ.get('PMI_RANK') or 0)
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
else:
return 0
def get_local_rank():
"""Find OMPI local rank without calling mpi functions
:rtype: int
"""
if os.environ.get('MPI_LOCALRANKID') is not None:
return int(os.environ.get('MPI_LOCALRANKID') or 0)
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
else:
return 0
def get_master_ip():
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
else:
return "127.0.0.1"
================================================
FILE: opensora/models/frame_interpolation/utils/flow_utils.py
================================================
import numpy as np
import torch
from PIL import ImageFile
import torch.nn.functional as F
ImageFile.LOAD_TRUNCATED_IMAGES = True
def warp(img, flow):
    B, _, H, W = flow.shape
    # Base sampling grid in normalized [-1, 1] coordinates.
    xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
    yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
    grid = torch.cat([xx, yy], 1).to(img)
    # Rescale the pixel-unit flow to the same normalized range.
    flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)
    # grid_sample expects the grid as [B, H, W, 2].
    grid_ = (grid + flow_).permute(0, 2, 3, 1)
    output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
    return output
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
Code follows the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
col = col+RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
colorwheel[col:col+YG, 1] = 255
col = col+YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
col = col+GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
colorwheel[col:col+CB, 2] = 255
col = col+CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
col = col+BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
colorwheel[col:col+MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi
fk = (a+1) / 2*(ncols-1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""
Expects a two-dimensional flow image of shape [H,W,2].
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0]
v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
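
To make the color lookup in `flow_uv_to_colors` concrete, the sketch below reproduces the index arithmetic for a single flow vector in plain Python. The helper name `wheel_position` is hypothetical; the segment sizes and formulas are copied from the functions above.

```python
import math

# Segment sizes of the color wheel: RY, YG, GC, CB, BM, MR.
SEGMENTS = [15, 6, 4, 11, 13, 6]
ncols = sum(SEGMENTS)


def wheel_position(u, v):
    """Return (k0, k1, f): the two neighboring wheel rows and the blend weight."""
    a = math.atan2(-v, -u) / math.pi       # normalized angle in [-1, 1]
    fk = (a + 1) / 2 * (ncols - 1)         # fractional row index in [0, ncols - 1]
    k0 = math.floor(fk)
    k1 = 0 if k0 + 1 == ncols else k0 + 1  # wrap around at the end of the wheel
    return k0, k1, fk - k0


# A unit flow pointing in the negative x direction.
pos = wheel_position(-1.0, 0.0)
print(ncols, pos)
```

The final color is the linear blend `(1 - f) * wheel[k0] + f * wheel[k1]`, which is then scaled by the flow magnitude.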
================================================
FILE: opensora/models/frame_interpolation/utils/utils.py
================================================
import re
import sys
import torch
import random
import numpy as np
from PIL import ImageFile
import torch.nn.functional as F
from imageio import imread, imwrite
ImageFile.LOAD_TRUNCATED_IMAGES = True
class AverageMeter():
def __init__(self):
self.reset()
def reset(self):
self.val = 0.
self.avg = 0.
self.sum = 0.
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class AverageMeterGroups:
def __init__(self) -> None:
self.meter_dict = dict()
def update(self, dict, n=1):
for name, val in dict.items():
if self.meter_dict.get(name) is None:
self.meter_dict[name] = AverageMeter()
self.meter_dict[name].update(val, n)
def reset(self, name=None):
if name is None:
for v in self.meter_dict.values():
v.reset()
else:
meter = self.meter_dict.get(name)
if meter is not None:
meter.reset()
def avg(self, name):
meter = self.meter_dict.get(name)
if meter is not None:
return meter.avg
class InputPadder:
""" Pads images such that dimensions are divisible by divisor """
def __init__(self, dims, divisor=16):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
def pad(self, *inputs):
if len(inputs) == 1:
return F.pad(inputs[0], self._pad, mode='replicate')
else:
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self, *inputs):
if len(inputs) == 1:
return self._unpad(inputs[0])
else:
return [self._unpad(x) for x in inputs]
def _unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
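
The padding arithmetic in `InputPadder` can be checked without tensors. The sketch below copies the two formulas into a hypothetical helper, `pad_amounts`, and returns the list in the same `[left, right, top, bottom]` order that `F.pad` expects for the last two dimensions.

```python
def pad_amounts(height, width, divisor=16):
    """Return [left, right, top, bottom] padding so both sizes become multiples of divisor."""
    pad_ht = (((height // divisor) + 1) * divisor - height) % divisor
    pad_wd = (((width // divisor) + 1) * divisor - width) % divisor
    # Split each total as evenly as possible between the two sides.
    return [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]


pads_a = pad_amounts(270, 480)  # width is already a multiple of 16
pads_b = pad_amounts(100, 250)  # both dimensions need padding
print(pads_a, pads_b)
```

The trailing `% divisor` is what keeps an already-aligned dimension (480 here) from receiving a full extra block of padding.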
def img2tensor(img):
if img.shape[-1] > 3:
img = img[:,:,:3]
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
def tensor2img(img_t):
return (img_t * 255.).detach(
).squeeze(0).permute(1, 2, 0).cpu().numpy(
).clip(0, 255).astype(np.uint8)
def seed_all(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def read(file):
if file.endswith('.float3'): return readFloat(file)
elif file.endswith('.flo'): return readFlow(file)
elif file.endswith('.ppm'): return readImage(file)
elif file.endswith('.pgm'): return readImage(file)
elif file.endswith('.png'): return readImage(file)
elif file.endswith('.jpg'): return readImage(file)
elif file.endswith('.pfm'): return readPFM(file)[0]
else: raise Exception('don\'t know how to read %s' % file)
def write(file, data):
if file.endswith('.float3'): return writeFloat(file, data)
elif file.endswith('.flo'): return writeFlow(file, data)
elif file.endswith('.ppm'): return writeImage(file, data)
elif file.endswith('.pgm'): return writeImage(file, data)
elif file.endswith('.png'): return writeImage(file, data)
elif file.endswith('.jpg'): return writeImage(file, data)
elif file.endswith('.pfm'): return writePFM(file, data)
else: raise Exception('don\'t know how to write %s' % file)
def readPFM(file):
    # Use a context manager so the file handle is closed even on malformed input.
    with open(file, 'rb') as f:
        header = f.readline().rstrip().decode("ascii")
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')
        dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')
        scale = float(f.readline().decode("ascii").rstrip())
        # A negative scale marks little-endian data.
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'
        data = np.fromfile(f, endian + 'f')
    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # PFM stores rows bottom-to-top.
    data = np.flipud(data)
    return data, scale
def writePFM(file, image, scale=1):
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')
    # PFM stores rows bottom-to-top.
    image = np.flipud(image)
    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
    # A negative scale marks little-endian data in the header.
    endian = image.dtype.byteorder
    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale
    with open(file, 'wb') as f:
        # The file is opened in binary mode, so every header line must be bytes.
        f.write(('PF\n' if color else 'Pf\n').encode())
        f.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode())
        f.write(('%f\n' % scale).encode())
        image.tofile(f)
def readFlow(name):
if name.endswith('.pfm') or name.endswith('.PFM'):
return readPFM(name)[0][:,:,0:2]
f = open(name, 'rb')
header = f.read(4)
if header.decode("utf-8") != 'PIEH':
raise Exception('Flow file header does not contain PIEH')
width = np.fromfile(f, np.int32, 1).squeeze()
height = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
return flow.astype(np.float32)
def readImage(name):
if name.endswith('.pfm') or name.endswith('.PFM'):
data = readPFM(name)[0]
if len(data.shape)==3:
return data[:,:,0:3]
else:
return data
return imread(name)
def writeImage(name, data):
if name.endswith('.pfm') or name.endswith('.PFM'):
return writePFM(name, data, 1)
return imwrite(name, data)
def writeFlow(name, flow):
f = open(name, 'wb')
f.write('PIEH'.encode('utf-8'))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
def readFloat(name):
f = open(name, 'rb')
if(f.readline().decode("utf-8")) != 'float\n':
raise Exception('float file %s did not contain keyword' % name)
dim = int(f.readline())
dims = []
count = 1
for i in range(0, dim):
d = int(f.readline())
dims.append(d)
count *= d
dims = list(reversed(dims))
data = np.fromfile(f, np.float32, count).reshape(dims)
if dim > 2:
data = np.transpose(data, (2, 1, 0))
data = np.transpose(data, (1, 0, 2))
return data
def writeFloat(name, data):
f = open(name, 'wb')
dim=len(data.shape)
if dim>3:
raise Exception('bad float file dimension: %d' % dim)
f.write(('float\n').encode('ascii'))
f.write(('%d\n' % dim).encode('ascii'))
if dim == 1:
f.write(('%d\n' % data.shape[0]).encode('ascii'))
else:
f.write(('%d\n' % data.shape[1]).encode('ascii'))
f.write(('%d\n' % data.shape[0]).encode('ascii'))
for i in range(2, dim):
f.write(('%d\n' % data.shape[i]).encode('ascii'))
data = data.astype(np.float32)
if dim==2:
data.tofile(f)
else:
np.transpose(data, (2, 0, 1)).tofile(f)
def check_dim_and_resize(tensor_list):
shape_list = []
for t in tensor_list:
shape_list.append(t.shape[2:])
if len(set(shape_list)) > 1:
desired_shape = shape_list[0]
print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
resize_tensor_list = []
for t in tensor_list:
resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear'))
tensor_list = resize_tensor_list
return tensor_list
================================================
FILE: opensora/models/prompt_refiner/inference.py
================================================
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
import argparse
def get_output(prompt):
template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
"Make sure it is a fluent sentence, not nonsense."
prompt = template.format(prompt)
messages = [
{"role": "system", "content": "You are a caption refiner."},
{"role": "user", "content": prompt}
]
input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([input_ids], return_tensors="pt").to(device)
generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print('\nInput\n:', prompt)
print('\nOutput\n:', response)
return response
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--mode_path", type=str, default="llama3_8B_lora_merged_cn")
parser.add_argument("--prompt", type=str, default='a dog is running.')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(args.mode_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.mode_path,torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
response = get_output(args.prompt)
================================================
FILE: opensora/models/prompt_refiner/merge.py
================================================
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import argparse
def get_lora_model(base_model_path, lora_model_input_path, lora_model_output_path):
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto",trust_remote_code=True)
model = PeftModel.from_pretrained(model, lora_model_input_path)
merged_model = model.merge_and_unload()
merged_model.save_pretrained(lora_model_output_path, safe_serialization=True)
print("Merge lora to base model")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
tokenizer.save_pretrained(lora_model_output_path)
print("Save tokenizer")
def get_model_result(base_model_path, fintune_model_path):
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
device = "cuda"
fintune_model = AutoModelForCausalLM.from_pretrained(
fintune_model_path,
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()
template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
"Make sure it is a fluent sentence, not nonsense."
prompt = "a dog和一只猫"
prompt = template.format(prompt)
messages = [
{"role": "system", "content": "You are a caption refiner."},
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
def get_result(model_inputs, model):
generated_ids = model.generate(
model_inputs.input_ids,
max_new_tokens=512,
eos_token_id=tokenizer.get_vocab()["<|eot_id|>"]
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
base_model_response = get_result(model_inputs, base_model)
fintune_model_response = get_result(model_inputs, fintune_model)
print("\nInput\n", prompt)
print("\nResult before fine-tune:\n", base_model_response)
print("\nResult after fine-tune:\n", fintune_model_response)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_path", type=str, default="Meta-Llama-3___1-8B-Instruct")
parser.add_argument("--lora_in_path", type=str, default="llama3_1_instruct_lora/checkpoint-1008")
parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora/llama3_8B_lora_merged_cn")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
get_lora_model(args.base_path, args.lora_in_path, args.lora_out_path)
get_model_result(args.base_path, args.lora_out_path)
================================================
FILE: opensora/models/prompt_refiner/train.py
================================================
from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig
from peft import LoraConfig, TaskType, get_peft_model
import torch
import argparse
ins = "Refine the sentence to contain subject description, action, scene description. " \
"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
"Make sure it is a fluent sentence, not nonsense."
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, default='refine_32255.json')
parser.add_argument("--model_path", type=str, default='Meta-Llama-3___1-8B-Instruct')
parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora")
args = parser.parse_args()
return args
args = parse_args()
df = pd.read_json(args.data_path)
ds = Dataset.from_pandas(df)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
def process_func(example):
MAX_LENGTH = 2048
input_ids, attention_mask, labels = [], [], []
instruction = tokenizer(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a caption refiner.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False) # add_special_tokens=False: do not prepend special tokens
response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False)
input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
if len(input_ids) > MAX_LENGTH:
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
tokenized_id = ds.map(process_func, remove_columns=ds.column_names)
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto",torch_dtype=torch.bfloat16)
print(model)
model.enable_input_require_grads()
config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
inference_mode=False,
r=64,
lora_alpha=64,
lora_dropout=0.1
)
print(config)
model = get_peft_model(model, config)
model.print_trainable_parameters()
args = TrainingArguments(
output_dir=args.lora_out_path,
per_device_train_batch_size=32,
gradient_accumulation_steps=1,
logging_steps=1,
num_train_epochs=1,
save_steps=20,
dataloader_num_workers=4,
learning_rate=1.5e-4,
warmup_ratio=0.1,
save_on_each_node=True,
gradient_checkpointing=True,
report_to='wandb',
)
trainer = Trainer(
model=model,
args=args,
train_dataset=tokenized_id,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()
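
The key step in `process_func` is the label masking: instruction tokens are labeled `-100` so that the loss is computed only on the response. The sketch below shows that construction on made-up token ids, without a tokenizer. The helper name `build_example` and the ids are hypothetical.

```python
IGNORE_INDEX = -100  # label value that the cross-entropy loss skips


def build_example(instruction_ids, response_ids, pad_token_id, max_length=2048):
    """Concatenate prompt and response ids; mask the prompt part in the labels."""
    input_ids = instruction_ids + response_ids + [pad_token_id]
    attention_mask = [1] * len(input_ids)
    labels = [IGNORE_INDEX] * len(instruction_ids) + response_ids + [pad_token_id]
    # Truncate all three sequences to the same maximum length.
    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }


# Made-up ids: three instruction tokens, two response tokens, pad id 0.
example = build_example([11, 12, 13], [21, 22], pad_token_id=0)
print(example)
```

All three lists stay the same length, which is what lets `DataCollatorForSeq2Seq` pad them together in a batch.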
================================================
FILE: opensora/models/text_encoder/__init__.py
================================================
from opensora.models.text_encoder.clip import CLIPWrapper
from opensora.models.text_encoder.t5 import T5Wrapper
text_encoder = {
'google/mt5-xl': T5Wrapper,
'google/mt5-xxl': T5Wrapper,
'google/umt5-xl': T5Wrapper,
'google/umt5-xxl': T5Wrapper,
'google/t5-v1_1-xl': T5Wrapper,
'DeepFloyd/t5-v1_1-xxl': T5Wrapper,
'openai/clip-vit-large-patch14': CLIPWrapper,
'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k': CLIPWrapper
}
def get_text_warpper(text_encoder_name):
"""Deprecated. Return the wrapper class whose registry key occurs in text_encoder_name."""
encoder_key = None
for key in text_encoder.keys():
if key in text_encoder_name:
encoder_key = key
break
text_enc = text_encoder.get(encoder_key, None)
assert text_enc is not None
return text_enc
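
The lookup above matches by substring, so a local cache path that contains a registry key still resolves to the right wrapper. A dependency-free sketch of that behavior follows, with strings standing in for the wrapper classes. The names `registry` and `lookup_wrapper` and the sample path are hypothetical.

```python
# String stand-ins for the wrapper classes, so the sketch needs no model code.
registry = {
    'google/mt5-xxl': 'T5Wrapper',
    'openai/clip-vit-large-patch14': 'CLIPWrapper',
}


def lookup_wrapper(name_or_path):
    """Return the first registered wrapper whose key occurs in name_or_path."""
    for key, wrapper in registry.items():
        if key in name_or_path:
            return wrapper
    raise KeyError(f"no text encoder registered for {name_or_path!r}")


from_hub_name = lookup_wrapper('google/mt5-xxl')
from_local_path = lookup_wrapper('/cache/models/openai/clip-vit-large-patch14')
print(from_hub_name, from_local_path)
```

Because the first matching key wins, registry keys should not be substrings of one another.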
================================================
FILE: opensora/models/text_encoder/clip.py
================================================
import torch
from torch import nn
from transformers import CLIPTextModelWithProjection
try:
import torch_npu
except:
torch_npu = None
class CLIPWrapper(nn.Module):
def __init__(self, args, **kwargs):
super(CLIPWrapper, self).__init__()
self.model_name = args.text_encoder_name_2
if torch_npu is not None:
self.model_name = '/home/save_dir/pretrained/clip/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/bc7788f151930d91b58474715fdce5524ad9a189'
else:
self.model_name = '/storage/cache_dir/CLIP-ViT-bigG-14-laion2B-39B-b160k'
print(f'Loading CLIP model from {self.model_name}...')
self.text_enc = CLIPTextModelWithProjection.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval()
def forward(self, input_ids, attention_mask):
text_encoder_embs = self.text_enc(input_ids=input_ids, output_hidden_states=True)[0]
return text_encoder_embs.detach()
================================================
FILE: opensora/models/text_encoder/t5.py
================================================
import torch
from torch import nn
from transformers import T5EncoderModel
try:
import torch_npu
except:
torch_npu = None
class T5Wrapper(nn.Module):
def __init__(self, args, **kwargs):
super(T5Wrapper, self).__init__()
self.model_name = args.text_encoder_name_1
print(f'Loading T5 model from {self.model_name}...')
self.text_enc = T5EncoderModel.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval()
def forward(self, input_ids, attention_mask):
text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
return text_encoder_embs.detach()
================================================
FILE: opensora/npu_config.py
================================================
import math
import mmap
import os
import pickle
import random
import numpy as np
import torch
import subprocess
import sys
import threading
import gc
import torch.distributed as dist
from opensora.adaptor.zp_manager import zp_manager
try:
import torch_npu
npu_is_available = True
from torch_npu.contrib import transfer_to_npu
except:
npu_is_available = False
from contextlib import contextmanager
import types
def compress_video(input_file, output_file, out_size):
    """Compress a video file with ffmpeg."""
    command = [
        'ffmpeg',
        '-i', input_file,
        '-vf', f"scale='min({out_size},iw)':'min({out_size},ih)':force_original_aspect_ratio=decrease",
        '-c:v', 'libx264',
        '-crf', '18',
        '-preset', 'slow',
        '-c:a', 'copy',
        output_file
    ]
    subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@contextmanager
def set_run_dtype(x, dtype=None):
    # Remember the dtype of the input tensor so results can be cast back to it
    npu_config.original_run_dtype = x.dtype
    # Set the dtype that operators should run in inside the `with` block
    npu_config.current_run_dtype = dtype
    try:
        # Yield control back to the body of the `with` statement
        yield
    finally:
        # Clear both dtypes when the `with` block exits
        npu_config.current_run_dtype = None
        npu_config.original_run_dtype = None
class NPUConfig:
N_NPU_PER_NODE = 8
def __init__(self):
self.on_npu = npu_is_available
self.node_world_size = self.N_NPU_PER_NODE
self.profiling = False
self.profiling_step = 5
self.enable_FA = True
self.enable_FP32 = False
self.load_pickle = True
self.use_small_dataset = False
self.current_run_dtype = None
self.original_run_dtype = None
self.zp_manager = zp_manager
self.replaced_type = torch.float32
self.conv_dtype = torch.float16
if self.enable_FA and self.enable_FP32:
self.inf_float = -10000.0
else:
self.inf_float = -10000.0
if self.use_small_dataset:
self.load_pickle = False
self._loss = []
self.work_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
self.pickle_save_path = f"{self.work_path}/pickles"
self.mm = dict()
if self.on_npu:
import deepspeed
import sys
torch_npu.npu.set_compile_mode(jit_compile=False)
import deepspeed.runtime.utils as utils
from opensora.adaptor.utils import all_gather_dp_groups, all_gather_into_tensor_dp_groups
utils.all_gather_dp_groups = all_gather_dp_groups
import deepspeed.runtime.bf16_optimizer as bf16_optimizer
from opensora.adaptor.bf16_optimizer import BF16_Optimizer
self.replace_methods(bf16_optimizer.BF16_Optimizer, BF16_Optimizer)
from opensora.adaptor.stage_1_and_2 import DeepSpeedZeroOptimizer
import deepspeed.runtime.zero.stage_1_and_2 as stage_1_and_2
self.replace_methods(stage_1_and_2.DeepSpeedZeroOptimizer, DeepSpeedZeroOptimizer, ['_has_inf_or_nan'])
import deepspeed.runtime.engine as engine
from opensora.adaptor.engine import DeepSpeedEngine
self.replace_methods(engine.DeepSpeedEngine, DeepSpeedEngine, skip_fcns=['__init__', '_copy_recovery_script', '_change_recovery_script_permissions'])
if "RANK" in os.environ:
self.rank = int(os.environ["RANK"])
self.world_size = int(os.environ["WORLD_SIZE"])
torch_npu.npu.set_device(self.get_local_rank())
else:
self.rank = torch.cuda.current_device()
self.world_size = self.N_NPU_PER_NODE
self.print_with_rank(f"The npu_config.on_npu is {self.on_npu}")
self.bind_thread_to_cpu()
gc.set_threshold(700, 10, 10000)
def get_total_cores(self):
try:
total_cores = os.sysconf('SC_NPROCESSORS_ONLN')
except (AttributeError, ValueError):
total_cores = os.cpu_count()
return total_cores
    def bind_thread_to_cpu(self):
        total_cores = self.get_total_cores()
        # Number of CPU cores available to each card
        cores_per_rank = total_cores // 8
        # Local rank of this process on its node
        local_rank = self.rank % 8
        # CPU core range assigned to the current rank
        start_core = local_rank * cores_per_rank
        end_core = start_core + cores_per_rank - 1
        # Build the core-range string passed to taskset
        cpu_cores_range = f"{start_core}-{end_core}"
        pid = os.getpid()
        command = f"taskset -cp {cpu_cores_range} {pid}"
        subprocess.run(command, shell=True, check=True)
        return f"Binding Cores:{self.rank}:{pid}:{cpu_cores_range}"
def replace_methods(self, target_class, source_class, skip_fcns=[], only_include_fcns=None):
for attr_name in dir(source_class):
attr_value = getattr(source_class, attr_name)
if attr_name in source_class.__dict__:
attr_class_value = source_class.__dict__[attr_name]
else:
attr_class_value = attr_value
if (isinstance(attr_class_value, staticmethod) or isinstance(attr_class_value, classmethod)
or attr_name in skip_fcns):
print(f"skip replace {attr_name}")
continue
if only_include_fcns is not None and attr_name not in only_include_fcns:
continue
elif isinstance(attr_value, types.FunctionType):
setattr(target_class, attr_name, attr_value)
def get_attention_mask(self, attention_mask, repeat_num):
if self.on_npu and attention_mask is not None:
if npu_config.enable_FA:
attention_mask = attention_mask.to(torch.bool)
attention_mask = attention_mask.repeat_interleave(repeat_num, dim=-2)
return attention_mask
def set_current_run_dtype(self, variables):
if variables[0].dtype != self.current_run_dtype and self.current_run_dtype is not None:
for index, var in enumerate(variables):
variables[index] = var.to(self.current_run_dtype)
return tuple(variables)
def restore_dtype(self, x):
if x.dtype != self.original_run_dtype and self.original_run_dtype is not None:
x = x.to(self.original_run_dtype)
return x
def get_output_video_path(self, name):
os.makedirs(f"{self.work_path}/output_videos", exist_ok=True)
return f"{self.work_path}/output_videos/{name}"
def get_node_id(self):
return self.rank // self.node_world_size
def get_node_size(self):
return self.world_size // self.node_world_size
def get_local_rank(self):
return self.rank % self.N_NPU_PER_NODE
def get_pickle_path(self, file_name):
return f"{self.pickle_save_path}/{file_name}_local_n63"
def free_mm(self):
for key, value in self.mm.items():
value.close()
self.mm.clear()
def __del__(self):
self.free_mm()
    def try_load_pickle(self, file_name, function):
        file_name = self.get_pickle_path(file_name)
        if os.path.exists(file_name) and self.load_pickle:
            with open(file_name, 'rb') as file:
                # self.mm[file_name] = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
                # # Read the data through mmap
                # loaded_data = pickle.loads(self.mm[file_name][:])
                loaded_data = pickle.load(file)
                return loaded_data
        else:
            data = function()
            if not self.use_small_dataset:
                if self.rank % self.N_NPU_PER_NODE == 0:
                    # Only local rank 0 on each node needs to save the file
                    os.makedirs(self.pickle_save_path, exist_ok=True)
                    with open(file_name, 'wb') as file:
                        pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)
            return data
def try_get_vid_path(self, file, out_size=1024):
output_file = file.rsplit(".", 1)[0] + f"_resize{out_size}.mp4"
if not os.path.exists(output_file):
return file
# compress_video(file, output_file, out_size)
return output_file
def npu_format_cast(self, x):
return torch_npu.npu_format_cast(x, 2)
    def calc_grad_norm(self, model):
        # Compute the gradient norm. The per-parameter computation below is
        # commented out, so this method currently always returns 0.
        # model_engine = accelerator.deepspeed_engine_wrapped.engine
        # gradients = model_engine.get_gradients()
        # grad_norm = get_grad_norm(gradients)
        grad_norm = 0
        n_grad = 0
        # for name, param in model.named_parameters():
        #     grad_data = deepspeed.utils.safe_get_full_grad(param)
        #     # self.print_tensor_stats(grad_data, name=name)
        #
        #     if grad_data is not None:
        #         param_norm = grad_data.norm(2)
        #         grad_norm += param_norm.item() ** 2
        #         n_grad += 1
        # grad_norm = (grad_norm / n_grad) ** (1. / 2)
        return grad_norm
def _run(self, operator, x, tmp_dtype, out_dtype=None, out_nd_format=False):
if self.on_npu:
if out_dtype is None:
out_dtype = x.dtype
with torch.cuda.amp.autocast(enabled=False):
x = operator.to(device=x.device, dtype=tmp_dtype)(x.to(tmp_dtype))
x = x.to(out_dtype)
if out_nd_format:
return self.npu_format_cast(x)
else:
return x
else:
return operator(x)
def run_group_norm(self, operator, x):
return self._run(operator, x, torch.float32)
def run_layer_norm(self, operator, x):
return self._run(operator, x, torch.float32)
    def print_tensor_stats(self, tensor, name="Tensor", rank=None):
        # Compare against None so that rank=0 is honored as a filter
        if rank is not None and rank != self.rank:
            return
        if tensor is None:
            self.print_msg(f"Tensor {name} is None.")
            return
        x_dtype = tensor.dtype
        tensor = tensor.to(torch.bfloat16)
        max_val = tensor.max().item()
        min_val = tensor.min().item()
        # Largest magnitude over all elements
        abs_max_val = max(abs(max_val), abs(min_val))
        mean_val = tensor.mean().item()
        median_val = tensor.median().item()
        std_val = tensor.std().item()
        shape = tensor.shape
        self.print_msg(
            f"{name} - Max: {max_val}, Min: {min_val}, Mean: {mean_val}, AbsMax: {abs_max_val}, "
            f"Median: {median_val}, Std: {std_val}, Shape: {shape}, Type: {x_dtype}")
def run_conv3d(self, operator, x, out_dtype):
return self._run(operator, x, self.conv_dtype, out_dtype, out_nd_format=True)
def run_pool_2d(self, operator, x):
return self._run(operator, x, self.replaced_type)
def run_pad_2d(self, operator, x, pad, mode="constant"):
if self.on_npu:
x_dtype = x.dtype
x = x.to(self.replaced_type)
x = operator(x, pad, mode)
x = x.to(x_dtype)
else:
x = operator(x, pad, mode)
return x
def seed_everything(self, seed=100):
seed += self.rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def print_with_rank(self, msg, rank=0, save=False):
if self.rank == rank:
print(f"{msg}", flush=True)
if save:
self._loss.append(msg)
def print_msg(self, msg, on=True, rank=None):
if on:
if self.rank == rank or rank is None:
print(f"[RANK-{self.rank}]: {msg}", flush=True)
def save_loss(self, filename, rank=0):
if self.rank == rank:
import json
with open(filename, 'w') as file:
json.dump(self._loss, file, indent=4)
def run_attention(self, query, key, value, atten_mask, input_layout, head_dim, head_num):
if self.enable_FA:
hidden_states = torch_npu.npu_fusion_attention(query, key, value,
atten_mask=atten_mask,
input_layout=input_layout,
scale=1 / math.sqrt(head_dim),
head_num=head_num)[0]
else:
hidden_states = self.scaled_dot_product_attention(query, key, value,
atten_mask=atten_mask,
input_layout=input_layout,
scale=1 / math.sqrt(head_dim),
head_num=head_num)
return hidden_states
def scaled_dot_product_attention(self, query, key, value, input_layout, head_num=None,
atten_mask=None, scale=None, dropout_p=0.0, is_causal=False) -> torch.Tensor:
# L, S = query.size(-2), key.size(-2)
def trans_tensor_shape(x, layout, head_num):
if layout == "BSH":
batch = x.shape[0]
x = x.view(batch, -1, head_num, x.shape[-1] // head_num).transpose(1, 2).contiguous()
elif layout == "SBH":
batch = x.shape[1]
x = x.view(-1, batch * head_num, x.shape[-1] // head_num).transpose(0, 1).contiguous()
x = x.view(batch, head_num, -1, x.shape[-1])
return x
query = trans_tensor_shape(query, input_layout, head_num)
key = trans_tensor_shape(key, input_layout, head_num)
value = trans_tensor_shape(value, input_layout, head_num)
attn_weight = query @ key.transpose(-2, -1) * scale
attn_bias = torch.zeros_like(attn_weight, dtype=query.dtype, device=query.device)
if is_causal:
assert atten_mask is None
temp_mask = torch.ones_like(attn_weight, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), npu_config.inf_float)
attn_bias = attn_bias.to(query.dtype)
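if atten_mask is not None:
assert (not self.enable_FA) and atten_mask.dtype != torch.bool, \
"attention_mask must not be bool type when using this function"
# Non-bool masks are treated as additive, as in the PyTorch reference implementation.
attn_bias = attn_bias + atten_mask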
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
output = attn_weight @ value
if input_layout == "BSH":
output = output.transpose(1, 2).contiguous().view(output.shape[0], -1, head_num * output.shape[-1])
else:
output = output.view(output.shape[0] * head_num, -1, output.shape[-1]).transpose(0, 1).contiguous()
output = output.view(output.shape[0], -1, head_num * output.shape[-1])
return output
def print_tensor_with_rank(self, name, tensor, rank=[0], dim_print_cnt=[]):
if type(rank) is not list:
rank = [rank]
if self.rank in rank:
def print_dim(tensor_, indices):
if tensor_.dim() == len(indices):
return '{0:10.5f} '.format(tensor_[tuple(indices)].detach().item())
else:
cur_dim = len(indices)
ret = ''
for x in range(0, tensor_.size(cur_dim), tensor_.size(cur_dim) // dim_print_cnt[cur_dim]):
ret += print_dim(tensor_, indices + [x])
return ret + '\n'
print(name, tensor.size(), self.rank, '\n', print_dim(tensor, []))
npu_config = NPUConfig()
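# Illustrative sketch, not part of the original module: how the causal branch of
# scaled_dot_product_attention is expected to behave, written in plain Python so it
# runs without torch. The helper names causal_bias and softmax are hypothetical, and
# -inf stands in for npu_config.inf_float, whose actual value is not shown here.
# Keys above the diagonal get a very negative additive bias, so softmax gives them
# zero weight.

```python
import math


def causal_bias(q_len, k_len, neg_inf=-math.inf):
    """Additive bias: 0 where key index <= query index, neg_inf elsewhere."""
    return [[0.0 if k <= q else neg_inf for k in range(k_len)] for q in range(q_len)]


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [[1.0, 2.0, 3.0] for _ in range(3)]  # toy attention scores
bias = causal_bias(3, 3)
weights = [
    softmax([s + b for s, b in zip(score_row, bias_row)])
    for score_row, bias_row in zip(scores, bias)
]
```

# The first query attends only to the first key; later queries spread their weight
# over the keys at or before their own position.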
================================================
FILE: opensora/sample/caption_refiner.py
================================================
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
TEMPLATE = """
Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
"Make sure it is a fluent sentence, not nonsense.
"""
class OpenSoraCaptionRefiner(nn.Module):
def __init__(self, args, dtype, device):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(
args.caption_refiner, trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
args.caption_refiner, torch_dtype=dtype, trust_remote_code=True
).to(device).eval()
self.device = device
def get_refiner_output(self, prompt):
prompt = TEMPLATE.format(prompt)
messages = [
{"role": "system", "content": "You are a caption refiner."},
{"role": "user", "content": prompt}
]
# With tokenize=False the chat template returns text, not token ids.
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
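# Illustrative sketch, not part of the original module: why get_refiner_output slices
# each generated sequence. A causal language model typically returns the prompt tokens
# followed by the newly generated ones, so the prompt prefix is dropped before decoding.
# The helper name strip_prompt_tokens and the toy token ids are hypothetical.

```python
def strip_prompt_tokens(prompt_ids_batch, generated_ids_batch):
    """Drop the echoed prompt prefix from each generated sequence."""
    return [
        generated[len(prompt):]
        for prompt, generated in zip(prompt_ids_batch, generated_ids_batch)
    ]


prompt_ids = [[101, 7, 8]]             # toy token ids for one prompt
generated_ids = [[101, 7, 8, 42, 43]]  # prompt echoed back, then two new tokens
new_tokens = strip_prompt_tokens(prompt_ids, generated_ids)
```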
================================================
FILE: opensora/sample/pipeline_inpaint.py
================================================
import inspect
import os
from typing import Callable, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import numpy as np
import torch
from einops import rearrange
from PIL import Image
import decord
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
import torch.nn.functional as F
from torchvision.transforms import Compose, Lambda, Resize
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDPMScheduler
from diffusers.utils import logging, BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from opensora.models.diffusion.opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3
from opensora.sample.pipeline_opensora import OpenSoraPipeline, OpenSoraPipelineOutput, rescale_noise_cfg
from opensora.dataset.transform import CenterCropResizeVideo, SpatialStrideCropVideo,ToTensorAfterResize, maxhwresize
from opensora.utils.mask_utils import MaskProcessor, MaskCompressor, GaussianNoiseAdder, MaskType, STR_TO_TYPE, TYPE_TO_STR
try:
import torch_npu
from opensora.npu_config import npu_config
from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info
except Exception:
torch_npu = None
npu_config = None
from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def is_video_file(file_path):
video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.mpeg', '.mpg', '.3gp'}
file_extension = os.path.splitext(file_path)[1].lower()
return file_extension in video_extensions
def is_image_file(file_path):
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
file_extension = os.path.splitext(file_path)[1].lower()
return file_extension in image_extensions
def open_image(file_path):
image = Image.open(file_path).convert("RGB")
return image
def open_video(file_path, start_frame_idx, num_frames, frame_interval=1):
decord_vr = decord.VideoReader(file_path, ctx=decord.cpu(0), num_threads=1)
total_frames = len(decord_vr)
frame_indices = list(range(start_frame_idx, min(start_frame_idx + num_frames * frame_interval, total_frames), frame_interval))
if len(frame_indices) == 0:
raise ValueError("No frames selected. Check your start_frame_idx and num_frames.")
if len(frame_indices) < num_frames:
raise ValueError(f"Requested {num_frames} frames but only {len(frame_indices)} frames are available, please adjust the start_frame_idx and num_frames or decrease the frame_interval.")
video_data = decord_vr.get_batch(frame_indices).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
return video_data
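# Illustrative sketch, not part of the original module: the frame-index arithmetic used
# by open_video, isolated from decord so it can be checked on its own. The helper name
# select_frame_indices is hypothetical; it mirrors the range() call and the length
# check above.

```python
def select_frame_indices(start_frame_idx, num_frames, total_frames, frame_interval=1):
    """Return num_frames indices spaced frame_interval apart, or raise if too few exist."""
    stop = min(start_frame_idx + num_frames * frame_interval, total_frames)
    indices = list(range(start_frame_idx, stop, frame_interval))
    if len(indices) < num_frames:
        raise ValueError(
            f"Requested {num_frames} frames but only {len(indices)} are available."
        )
    return indices


every_other = select_frame_indices(0, 4, total_frames=100, frame_interval=2)
```

# A clip that is too short for the requested sampling raises ValueError rather than
# silently returning fewer frames.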
def get_pixel_values(file_path, num_frames):
if is_image_file(file_path[0]):
pixel_values = [open_image(path) for path in file_path]
pixel_values = [torch.from_numpy(np.array(image)) for image in pixel_values]
pixel_values = [rearrange(image, 'h w c -> c h w').unsqueeze(0) for image in pixel_values]
elif is_video_file(file_path[0]):
pixel_values = [open_video(video_path, 0, num_frames) for video_path in file_path]
return pixel_values
class OpenSoraInpaintPipeline(OpenSoraPipeline):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: MT5Tokenizer,
transformer: OpenSoraInpaint_v1_3,
scheduler: DDPMScheduler,
text_encoder_2: CLIPTextModelWithProjection = None,
tokenizer_2: CLIPTokenizer = None,
):
super().__init__(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
)
# For continuation or random masks, half of the frames are kept clear by default; change min_clear_ratio / max_clear_ratio to adjust this
self.mask_processor = MaskProcessor(min_clear_ratio=0.5, max_clear_ratio=0.5)
self.mask_compressor = MaskCompressor(ae_stride_t=self.vae.vae_scale_factor[0], ae_stride_h=self.vae.vae_scale_factor[1], ae_stride_w=self.vae.vae_scale_factor[2])
self.noise_adder = None
def check_inputs(
self,
conditional_pixel_values_path,
conditional_pixel_values_indices,
mask_type,
max_hxw,
noise_strength,
prompt,
num_frames,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if conditional_pixel_values_path is None:
raise ValueError("conditional_pixel_values_path should be provided")
else:
if not isinstance(conditional_pixel_values_path, list) or not isinstance(conditional_pixel_values_path[0], str):
raise ValueError("conditional_pixel_values_path should be a list of strings")
if not is_image_file(conditional_pixel_values_path[0]) and not is_video_file(conditional_pixel_values_path[0]):
raise ValueError("conditional_pixel_values_path should be an image or video file path")
if is_video_file(conditional_pixel_values_path[0]) and len(conditional_pixel_values_path) > 1:
raise ValueError("conditional_pixel_values_path should be a list of image file paths or a single video file path")
if conditional_pixel_values_indices is not None \
and (not isinstance(conditional_pixel_values_indices, list) or not isinstance(conditional_pixel_values_indices[0], int) \
or len(conditional_pixel_values_indices) != len(conditional_pixel_values_path)):
raise ValueError("conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path")
if mask_type is not None and mask_type not in STR_TO_TYPE.keys() and mask_type not in STR_TO_TYPE.values():
raise ValueError(f"Invalid mask type: {mask_type}")
if not isinstance(max_hxw, int) or not (max_hxw >= 102400 and max_hxw <= 236544):
raise ValueError("max_hxw should be an integer between 102400 and 236544")
if not isinstance(noise_strength, float) or not (noise_strength >= 0 and noise_strength <= 1):
raise ValueError("noise_strength should be a float between 0 and 1")
super().check_inputs(
prompt,
num_frames,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
def get_resize_transform(
self,
ori_height,
ori_width,
height=None,
width=None,
crop_for_hw=False,
hw_stride=32,
max_hxw=236544, # roughly 480 x 480 (= 230400) pixels; e.g. 352 x 672 = 236544
):
if crop_for_hw:
assert height is not None and width is not None
transform = CenterCropResizeVideo((height, width))
else:
new_height, new_width = maxhwresize(ori_height, ori_width, max_hxw)
transform = Compose(
[
CenterCropResizeVideo((new_height, new_width)), # We use CenterCropResizeVideo to share the same height and width, ensuring that the shape of the crop remains consistent when multiple images are captured
SpatialStrideCropVideo(stride=hw_stride),
]
)
return transform
def get_video_transform(self):
norm_fun = Lambda(lambda x: 2. * x - 1.)
transform = Compose([
ToTensorAfterResize(),
norm_fun
])
return transform
def get_mask_type_cond_indices(self, mask_type, conditional_pixel_values_path, conditional_pixel_values_indices, num_frames):
if mask_type is not None and mask_type in STR_TO_TYPE.keys():
mask_type = STR_TO_TYPE[mask_type]
if is_image_file(conditional_pixel_values_path[0]):
if len(conditional_pixel_values_path) == 1:
mask_type = MaskType.i2v if mask_type is None else mask_type
if num_frames > 1:
conditional_pixel_values_indices = [0] if conditional_pixel_values_indices is None else conditional_pixel_values_indices
assert len(conditional_pixel_values_indices) == 1, "conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path"
elif len(conditional_pixel_values_path) == 2:
mask_type = MaskType.transition if mask_type is None else mask_type
if num_frames > 1:
conditional_pixel_values_indices = [0, -1] if conditional_pixel_values_indices is None else conditional_pixel_values_indices
assert len(conditional_pixel_values_indices) == 2, "conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path"
else:
if num_frames > 1:
assert conditional_pixel_values_indices is not None and len(conditional_pixel_values_path) == len(conditional_pixel_values_indices), "conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path"
mask_type = MaskType.random_temporal if mask_type is None else mask_type
elif is_video_file(conditional_pixel_values_path[0]):
# When the input is a video, video continuation is used by default; with the default clear ratio of 0.5, the given frames make up half of the output
mask_type = MaskType.continuation if mask_type is None else mask_type
return mask_type, conditional_pixel_values_indices
def get_masked_pixel_values_mask(
self,
conditional_pixel_values,
conditional_pixel_values_indices,
mask_type,
batch_size,
num_samples_per_prompt,
num_frames,
height,
width,
video_transform,
weight_dtype,
device
):
if device is None:
device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')
conditional_pixel_values = conditional_pixel_values.to(device=device, dtype=weight_dtype)
if conditional_pixel_values.shape[0] == num_frames:
inpaint_cond_data = self.mask_processor(conditional_pixel_values, mask_type=mask_type)
masked_pixel_values, mask = inpaint_cond_data['masked_pixel_values'], inpaint_cond_data['mask']
else:
input_pixel_values = torch.zeros([num_frames, 3, height, width], device=device, dtype=weight_dtype)
input_mask = torch.ones([num_frames, 1, height, width], device=device, dtype=weight_dtype)
input_pixel_values[conditional_pixel_values_indices] = conditional_pixel_values
input_mask[conditional_pixel_values_indices] = 0
masked_pixel_values = input_pixel_values * (input_mask < 0.5)
mask = input_mask
print('conditional_pixel_values_indices', conditional_pixel_values_indices)
print('mask_type', TYPE_TO_STR[mask_type])
masked_pixel_values = video_transform(masked_pixel_values)
masked_pixel_values = masked_pixel_values.unsqueeze(0).repeat(batch_size * num_samples_per_prompt, 1, 1, 1, 1).transpose(1, 2).contiguous() # b c t h w
mask = mask.unsqueeze(0).repeat(batch_size * num_samples_per_prompt, 1, 1, 1, 1).transpose(1, 2).contiguous() # b c t h w
if self.noise_adder is not None:
# add some noise to improve motion strength
masked_pixel_values = self.noise_adder(masked_pixel_values, mask)
masked_pixel_values = masked_pixel_values.to(self.vae.vae.dtype)
masked_pixel_values = self.vae.encode(masked_pixel_values)
mask = self.mask_compressor(mask)
masked_pixel_values = torch.cat([masked_pixel_values] * 2) if self.do_classifier_free_guidance else masked_pixel_values
mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
masked_pixel_values = masked_pixel_values.to(weight_dtype)
mask = mask.to(weight_dtype)
return masked_pixel_values, mask
@torch.no_grad()
def __call__(
self,
conditional_pixel_values_path: Union[str, List[str]] = None,
conditional_pixel_values_indices: Union[int, List[int]] = None,
mask_type: Union[str, MaskType] = None,
crop_for_hw: bool = False,
max_hxw: int = 236544,
noise_strength: Optional[float] = 0.0,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_samples_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
max_sequence_length: int = 512,
device = None,
):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default num_frames, height and width
num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1
height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
conditional_pixel_values_path,
conditional_pixel_values_indices,
mask_type,
max_hxw,
noise_strength,
prompt,
num_frames,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = device or getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_samples_per_prompt=num_samples_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
text_encoder_index=0,
)
if self.tokenizer_2 is not None:
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_samples_per_prompt=num_samples_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=77,
text_encoder_index=1,
)
else:
prompt_embeds_2 = None
negative_prompt_embeds_2 = None
prompt_attention_mask_2 = None
negative_prompt_attention_mask_2 = None
# ==================prepare inpaint data=====================================
if noise_strength != 0:
self.noise_adder = GaussianNoiseAdder(mean=np.log(noise_strength), std=0.01, clear_ratio=0)
else:
# Reset so that a noise adder from a previous call is not reused.
self.noise_adder = None
mask_type, conditional_pixel_values_indices = self.get_mask_type_cond_indices(mask_type, conditional_pixel_values_path, conditional_pixel_values_indices, num_frames)
conditional_pixel_values = get_pixel_values(conditional_pixel_values_path, num_frames)
min_height = min([pixels.shape[2] for pixels in conditional_pixel_values])
min_width = min([pixels.shape[3] for pixels in conditional_pixel_values])
resize_transform = self.get_resize_transform(
ori_height=min_height,
ori_width=min_width,
height=height,
width=width,
crop_for_hw=crop_for_hw,
max_hxw=max_hxw,
)
video_transform = self.get_video_transform()
conditional_pixel_values = torch.cat([resize_transform(pixels) for pixels in conditional_pixel_values])
real_height, real_width = conditional_pixel_values.shape[-2], conditional_pixel_values.shape[-1]
# ==================prepare inpaint data=====================================
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
if get_sequence_parallel_state():
world_size = hccl_info.world_size if torch_npu is not None else nccl_info.world_size
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_samples_per_prompt,
num_channels_latents,
(num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames,
real_height,
real_width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# ==============================create mask=====================================
masked_pixel_values, mask = self.get_masked_pixel_values_mask(
conditional_pixel_values,
conditional_pixel_values_indices,
mask_type,
batch_size,
num_samples_per_prompt,
num_frames,
real_height,
real_width,
video_transform,
prompt_embeds.dtype,
device
)
# ==============================create mask=====================================
# 7. Concatenate negative and positive prompt embeddings for classifier-free guidance
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
if self.tokenizer_2 is not None:
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
if self.tokenizer_2 is not None:
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
# ==================make sp=====================================
if get_sequence_parallel_state():
prompt_embeds = rearrange(
prompt_embeds,
'b (n x) h -> b n x h',
n=world_size,
x=prompt_embeds.shape[1] // world_size
).contiguous()
rank = hccl_info.rank if torch_npu is not None else nccl_info.rank
prompt_embeds = prompt_embeds[:, rank, :, :]
latents_num_frames = latents.shape[2]
masked_pixel_values = masked_pixel_values[:, :, latents_num_frames * rank: latents_num_frames * (rank + 1)]
mask = mask[:, :, latents_num_frames * rank: latents_num_frames * (rank + 1)]
# ==================make sp=====================================
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# inpaint
latent_model_input = torch.cat([latent_model_input, masked_pixel_values, mask], dim=1)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# ==================prepare my shape=====================================
# predict the noise residual
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
if prompt_attention_mask.ndim == 2:
prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2:
prompt_embeds_2 = prompt_embeds_2.unsqueeze(1) # b d -> b 1 d
attention_mask = torch.ones_like(latent_model_input)[:, 0].to(device=device)
# ==================prepare my shape=====================================
# ==================make sp=====================================
if get_sequence_parallel_state():
attention_mask = attention_mask.repeat(1, world_size, 1, 1)
# ==================make sp=====================================
noise_pred = self.transformer(
latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=t_expand,
pooled_projections=prompt_embeds_2,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# ==================make sp=====================================
if get_sequence_parallel_state():
latents_shape = list(latents.shape) # b c t//sp h w
full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # b*sp c t//sp h w
all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device)
torch.distributed.all_gather_into_tensor(all_latents, latents)
latents_list = list(all_latents.chunk(world_size, dim=0))
latents = torch.cat(latents_list, dim=2)
# ==================make sp=====================================
if output_type != "latent":
videos = self.decode_latents(latents)
videos = videos[:, :num_frames, :real_height, :real_width]
else:
videos = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (videos, )
return OpenSoraPipelineOutput(videos=videos)
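# Illustrative sketch, not part of the original module: the sequence-parallel
# bookkeeping in __call__, reduced to plain Python lists. Each rank handles a
# ceil-divided slice of the temporal axis, and the slices are concatenated along time
# after the denoising loop. The helper names frames_per_rank, shard and gather are
# hypothetical, and a list of frame ids stands in for the latent tensor.

```python
def frames_per_rank(num_frames, world_size):
    """Ceil division, as in (num_frames + world_size - 1) // world_size."""
    return (num_frames + world_size - 1) // world_size


def shard(frames, rank, world_size):
    """Slice of the temporal axis owned by one rank."""
    per_rank = frames_per_rank(len(frames), world_size)
    return frames[per_rank * rank: per_rank * (rank + 1)]


def gather(shards):
    """Concatenate the per-rank slices back along the temporal axis."""
    return [frame for part in shards for frame in part]


frames = list(range(8))  # toy temporal axis with 8 frames
world_size = 4
shards = [shard(frames, rank, world_size) for rank in range(world_size)]
restored = gather(shards)
```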
================================================
FILE: opensora/sample/pipeline_opensora.py
================================================
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import numpy as np
import torch
from einops import rearrange
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDPMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging, BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3
try:
import torch_npu
from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info
except Exception:
torch_npu = None
from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class OpenSoraPipelineOutput(BaseOutput):
videos: Union[List[torch.FloatTensor], np.ndarray]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
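# Illustrative sketch, not part of the original module: classifier-free guidance
# followed by the rescaling done in rescale_noise_cfg, on flat Python lists instead of
# tensors. Each list stands for one sample; statistics.stdev is the sample standard
# deviation, which matches torch's default unbiased std. The helper names cfg_combine
# and rescale are hypothetical.

```python
import statistics


def cfg_combine(uncond, text, guidance_scale):
    """noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)."""
    return [u + guidance_scale * (t - u) for u, t in zip(uncond, text)]


def rescale(noise_cfg, noise_text, guidance_rescale):
    """Match the std of the guided prediction to the text prediction, then blend."""
    factor = statistics.stdev(noise_text) / statistics.stdev(noise_cfg)
    rescaled = [c * factor for c in noise_cfg]
    return [
        guidance_rescale * r + (1 - guidance_rescale) * c
        for r, c in zip(rescaled, noise_cfg)
    ]


uncond = [0.0, 0.0, 0.0, 0.0]
text = [1.0, 2.0, 3.0, 4.0]
guided = cfg_combine(uncond, text, guidance_scale=5.0)
fully_rescaled = rescale(guided, text, guidance_rescale=1.0)
unchanged = rescale(guided, text, guidance_rescale=0.0)
```

# With guidance_rescale=0.0 the guided prediction is returned as is, which is why the
# pipeline only calls rescale_noise_cfg when guidance_rescale > 0.0.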
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
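# Illustrative sketch, not part of the original module: the capability check that
# retrieve_timesteps performs with inspect.signature before forwarding custom
# timesteps or sigmas. DummyScheduler and supports_argument are hypothetical stand-ins;
# a real scheduler would come from diffusers.

```python
import inspect


class DummyScheduler:
    """Hypothetical scheduler whose set_timesteps accepts custom timesteps but not sigmas."""

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        if timesteps is not None:
            self.timesteps = list(timesteps)
        else:
            self.timesteps = list(range(num_inference_steps - 1, -1, -1))


def supports_argument(scheduler, name):
    """True if scheduler.set_timesteps declares a parameter called name."""
    return name in inspect.signature(scheduler.set_timesteps).parameters


scheduler = DummyScheduler()
can_use_timesteps = supports_argument(scheduler, "timesteps")
can_use_sigmas = supports_argument(scheduler, "sigmas")
```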
class OpenSoraPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = [
"text_encoder_2",
"tokenizer_2",
"text_encoder",
"tokenizer",
]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"prompt_embeds_2",
"negative_prompt_embeds_2",
]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: MT5Tokenizer,
transformer: OpenSoraT2V_v1_3,
scheduler: DDPMScheduler,
text_encoder_2: CLIPTextModelWithProjection = None,
tokenizer_2: CLIPTokenizer = None,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
)
def encode_prompt(
self,
prompt: str,
device: torch.device = None,
dtype: torch.dtype = None,
num_samples_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: Optional[int] = None,
text_encoder_index: int = 0,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_samples_per_prompt (`int`):
number of samples (videos or images) that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the generation. If not defined and `negative_prompt_embeds` is not
passed, an empty string is used. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
not greater than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for T5 and `1` for CLIP.
"""
if dtype is None:
if self.text_encoder_2 is not None:
dtype = self.text_encoder_2.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
if device is None:
device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')
tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = tokenizers[text_encoder_index]
text_encoder = text_encoders[text_encoder_index]
if max_sequence_length is None:
if text_encoder_index == 0:
max_length = 512
if text_encoder_index == 1:
max_length = 77
else:
max_length = max_sequence_length
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because the text encoder can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask,
)
prompt_embeds = prompt_embeds[0]
if text_encoder_index == 1:
prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d for clip
prompt_attention_mask = prompt_attention_mask.repeat(num_samples_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_samples_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_samples_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
attention_mask=negative_prompt_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if text_encoder_index == 1:
negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1) # b d -> b 1 d for clip
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_samples_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_samples_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_samples_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
num_frames,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if (num_frames - 1) % 4 != 0:
raise ValueError(f"`num_frames - 1` has to be divisible by 4 but `num_frames` is {num_frames}.")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
raise ValueError(
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
raise ValueError(
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
f" {negative_prompt_embeds_2.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
(int(num_frames) - 1) // self.vae.vae_scale_factor[0] + 1,
int(height) // self.vae.vae_scale_factor[1],
int(width) // self.vae.vae_scale_factor[2],
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_samples_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
max_sequence_length: int = 512,
device = None,
):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default num_frames, height and width
num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1
height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
num_frames,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = device or getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_samples_per_prompt=num_samples_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
text_encoder_index=0,
)
if self.tokenizer_2 is not None:
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_samples_per_prompt=num_samples_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=77,
text_encoder_index=1,
)
else:
prompt_embeds_2 = None
negative_prompt_embeds_2 = None
prompt_attention_mask_2 = None
negative_prompt_attention_mask_2 = None
# 4. Prepare timesteps
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
if get_sequence_parallel_state():
world_size = hccl_info.world_size if torch_npu is not None else nccl_info.world_size
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_samples_per_prompt,
num_channels_latents,
(num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
else:
extra_step_kwargs = {}
# 7. Concatenate unconditional and conditional embeddings for classifier-free guidance
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
if self.tokenizer_2 is not None:
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
if self.tokenizer_2 is not None:
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
# ==================make sp=====================================
if get_sequence_parallel_state():
prompt_embeds = rearrange(
prompt_embeds,
'b (n x) h -> b n x h',
n=world_size,
x=prompt_embeds.shape[1] // world_size
).contiguous()
rank = hccl_info.rank if torch_npu is not None else nccl_info.rank
prompt_embeds = prompt_embeds[:, rank, :, :]
# ==================make sp=====================================
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timestep = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
else:
timestep = t.expand(latent_model_input.shape[0])
# ==================prepare my shape=====================================
# predict the noise residual
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
if prompt_attention_mask.ndim == 2:
prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2:
prompt_embeds_2 = prompt_embeds_2.unsqueeze(1) # b d -> b 1 d
attention_mask = torch.ones_like(latent_model_input)[:, 0].to(device=device)
# ==================prepare my shape=====================================
# ==================make sp=====================================
if get_sequence_parallel_state():
attention_mask = attention_mask.repeat(1, world_size, 1, 1)
# ==================make sp=====================================
noise_pred = self.transformer(
latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=timestep,
pooled_projections=prompt_embeds_2,
return_dict=False,
)[0]
assert not torch.any(torch.isnan(noise_pred))
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0 and not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# ==================make sp=====================================
if get_sequence_parallel_state():
latents_shape = list(latents.shape) # b c t//sp h w
full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # b*sp c t//sp h w
all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device)
torch.distributed.all_gather_into_tensor(all_latents, latents)
latents_list = list(all_latents.chunk(world_size, dim=0))
latents = torch.cat(latents_list, dim=2)
# ==================make sp=====================================
if not output_type == "latent":
videos = self.decode_latents(latents)
videos = videos[:, :num_frames, :height, :width]
else:
videos = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (videos, )
return OpenSoraPipelineOutput(videos=videos)
def decode_latents(self, latents):
print(f'before vae decode {latents.shape}', torch.max(latents).item(), torch.min(latents).item(), torch.mean(latents).item(), torch.std(latents).item())
video = self.vae.decode(latents.to(self.vae.vae.dtype))
print(f'after vae decode {video.shape}', torch.max(video).item(), torch.min(video).item(), torch.mean(video).item(), torch.std(video).item())
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() # b t h w c
return video
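# Illustrative sketch (not part of the original file): the shape arithmetic used by
# `prepare_latents` and by the default `num_frames`/`height`/`width` in `__call__`.
# The scale factor (4, 8, 8) is an assumption inferred from the divisibility checks in
# `check_inputs`; the real value comes from `self.vae.vae_scale_factor`. The helper names
# `latent_shape` and `pixel_shape` are hypothetical.

```python
from typing import Tuple


def latent_shape(num_frames: int, height: int, width: int,
                 scale: Tuple[int, int, int] = (4, 8, 8)) -> Tuple[int, int, int]:
    """Map a pixel-space (T, H, W) to the latent (T', H', W').

    The first frame is kept as-is and the remaining frames are compressed
    in time, hence the `(T - 1) // st + 1` form.
    """
    st, sh, sw = scale
    if (num_frames - 1) % st != 0:
        raise ValueError(f"`num_frames - 1` must be divisible by {st}, got num_frames={num_frames}")
    if height % sh != 0 or width % sw != 0:
        raise ValueError(f"`height` and `width` must be divisible by {sh} and {sw}")
    return (num_frames - 1) // st + 1, height // sh, width // sw


def pixel_shape(latent_t: int, latent_h: int, latent_w: int,
                scale: Tuple[int, int, int] = (4, 8, 8)) -> Tuple[int, int, int]:
    """Inverse mapping, matching how the defaults are derived from the transformer config."""
    st, sh, sw = scale
    return (latent_t - 1) * st + 1, latent_h * sh, latent_w * sw


if __name__ == "__main__":
    print(latent_shape(93, 480, 640))
    print(pixel_shape(24, 60, 80))
```

# Under the assumed scale factor, a 93-frame 480x640 request and a (24, 60, 80) latent grid
# map onto each other in both directions.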
================================================
FILE: opensora/sample/rec_image.py
================================================
import sys
sys.path.append(".")
from PIL import Image
import torch
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Lambda
from torch.nn import functional as F
import argparse
import numpy as np
from opensora.models.causalvideovae import ae_wrapper
def preprocess(video_data: Image.Image, short_size: int = 128) -> torch.Tensor:
transform = Compose(
[
ToTensor(),
Lambda(lambda x: 2. * x - 1.),
Resize(size=short_size),
]
)
outputs = transform(video_data)
outputs = outputs.unsqueeze(0).unsqueeze(2)
return outputs
def main(args: argparse.Namespace):
image_path = args.image_path
short_size = args.short_size
device = args.device
kwarg = {}
# vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device)
vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)
if args.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = args.tile_overlap_factor
vae.eval()
vae = vae.to(device)
vae = vae.half()
with torch.no_grad():
x_vae = preprocess(Image.open(image_path), short_size)
x_vae = x_vae.to(device, dtype=torch.float16) # b c t h w
latents = vae.encode(x_vae)
latents = latents.to(torch.float16)
image_recon = vae.decode(latents) # b t c h w
x = image_recon[0, 0, :, :, :]
x = x.squeeze()
x = x.detach().cpu().numpy()
x = np.clip(x, -1, 1)
x = (x + 1) / 2
x = (255*x).astype(np.uint8)
x = x.transpose(1,2,0)
image = Image.fromarray(x)
image.save(args.rec_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', type=str, default='')
parser.add_argument('--rec_path', type=str, default='')
parser.add_argument('--ae', type=str, default='')
parser.add_argument('--ae_path', type=str, default='')
parser.add_argument('--model_path', type=str, default='results/pretrained')
parser.add_argument('--short_size', type=int, default=336)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
parser.add_argument('--enable_tiling', action='store_true')
args = parser.parse_args()
main(args)
================================================
FILE: opensora/sample/rec_video.py
================================================
import math
import random
import argparse
from typing import Optional
import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from decord import VideoReader, cpu
from torch.nn import functional as F
from pytorchvideo.transforms import ShortSideScale
from torchvision.transforms import Lambda, Compose
import sys
from opensora.models.causalvideovae import ae_wrapper
from opensora.dataset.transform import ToTensorVideo, CenterCropResizeVideo
def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
height, width, channels = image_array[0].shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
for image in image_array:
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # OpenCV writes frames in BGR order
video_writer.write(image_bgr)
video_writer.release()
def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
x = x.detach().cpu()
x = torch.clamp(x, -1, 1)
x = (x + 1) / 2
x = x.permute(0, 2, 3, 1).numpy()
x = (255 * x).astype(np.uint8)
array_to_video(x, fps=fps, output_file=output_file)
return
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
decord_vr = VideoReader(video_path, ctx=cpu(0))
total_frames = len(decord_vr)
sample_frames_len = sample_rate * num_frames
# if total_frames > sample_frames_len:
# s = random.randint(0, total_frames - sample_frames_len - 1)
# s = 0
# e = s + sample_frames_len
# num_frames = num_frames
# else:
# s = 0
# e = total_frames
# num_frames = int(total_frames / sample_frames_len * num_frames)
s = 0
e = sample_frames_len
print(f'sampling {num_frames} frames from the first {sample_frames_len} frames of {video_path} '
f'(total frames: {total_frames})')
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
video_data = torch.from_numpy(video_data)
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
return video_data
def preprocess(video_data: torch.Tensor, height: int = 128, width: int = 128) -> torch.Tensor:
transform = Compose(
[
ToTensorVideo(),
CenterCropResizeVideo((height, width)),
Lambda(lambda x: 2. * x - 1.)
]
)
video_outputs = transform(video_data)
video_outputs = torch.unsqueeze(video_outputs, 0)
return video_outputs
def main(args: argparse.Namespace):
device = args.device
kwarg = {}
# vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device)
# vae = CausalVAEModelWrapper(args.ae_path, **kwarg).to(device)
vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)
if args.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = args.tile_overlap_factor
# vae.vae.tile_sample_min_size = 512
# vae.vae.tile_latent_min_size = 64
# vae.vae.tile_sample_min_size_t = 29
# vae.vae.tile_latent_min_size_t = 8
# if args.save_memory:
# vae.vae.tile_sample_min_size = 256
# vae.vae.tile_latent_min_size = 32
# vae.vae.tile_sample_min_size_t = 9
# vae.vae.tile_latent_min_size_t = 3
dtype = torch.float32
vae.eval()
vae = vae.to(device, dtype=dtype)
with torch.no_grad():
x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height,
args.width)
print(x_vae.shape)
x_vae = x_vae.to(device, dtype=dtype) # b c t h w
# for i in range(10000):
latents = vae.encode(x_vae)
print(latents.shape)
latents = latents.to(dtype)
video_recon = vae.decode(latents) # b t c h w
print(video_recon.shape)
# vae = vae.half()
# from tqdm import tqdm
# with torch.no_grad():
# x_vae = torch.rand(1, 3, 93, 720, 1280)
# print(x_vae.shape)
# x_vae = x_vae.to(device, dtype=torch.float16) # b c t h w
# # x_vae = x_vae.to(device) # b c t h w
# for i in tqdm(range(100000)):
# latents = vae.encode(x_vae)
# print(latents.shape)
# latents = latents.to(torch.float16)
# video_recon = vae.decode(latents) # b t c h w
# print(video_recon.shape)
custom_to_video(video_recon[0], fps=args.fps, output_file=args.rec_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--video_path', type=str, default='')
parser.add_argument('--rec_path', type=str, default='')
parser.add_argument('--ae', type=str, default='')
parser.add_argument('--ae_path', type=str, default='')
parser.add_argument('--model_path', type=str, default='results/pretrained')
parser.add_argument('--fps', type=int, default=30)
parser.add_argument('--height', type=int, default=336)
parser.add_argument('--width', type=int, default=336)
parser.add_argument('--num_frames', type=int, default=100)
parser.add_argument('--sample_rate', type=int, default=1)
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
parser.add_argument('--tile_sample_min_size', type=int, default=512)
parser.add_argument('--tile_sample_min_size_t', type=int, default=33)
parser.add_argument('--tile_sample_min_size_dec', type=int, default=256)
parser.add_argument('--tile_sample_min_size_dec_t', type=int, default=33)
parser.add_argument('--enable_tiling', action='store_true')
parser.add_argument('--save_memory', action='store_true')
args = parser.parse_args()
main(args)
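# Illustrative sketch (not part of the original file): how `read_video` picks frame indices.
# It approximates `np.linspace(0, num_frames * sample_rate - 1, num_frames, dtype=int)` with
# integer arithmetic, so results may differ from NumPy in rare floating-point rounding cases.
# The bounds check is an added assumption: the original function does not verify that the
# video is long enough. The helper name `sample_frame_indices` is hypothetical.

```python
from typing import List


def sample_frame_indices(total_frames: int, num_frames: int, sample_rate: int) -> List[int]:
    """Spread `num_frames` indices over the first `num_frames * sample_rate` frames."""
    if num_frames < 1 or sample_rate < 1:
        raise ValueError("`num_frames` and `sample_rate` must be positive")
    span = num_frames * sample_rate
    if span > total_frames:
        raise ValueError(
            f"need {span} frames (num_frames * sample_rate) but the video has only {total_frames}"
        )
    if num_frames == 1:
        return [0]
    last = span - 1
    # Evenly spaced positions between 0 and `last`, truncated to integers.
    return [(i * last) // (num_frames - 1) for i in range(num_frames)]


if __name__ == "__main__":
    print(sample_frame_indices(total_frames=100, num_frames=5, sample_rate=2))
```

# Because both endpoints are included, the spacing is not exactly `sample_rate`:
# with 5 frames at rate 2 the indices end at 9 rather than 8.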
================================================
FILE: opensora/sample/sample.py
================================================
import os
import torch
try:
import torch_npu
from opensora.npu_config import npu_config
except Exception:
# torch_npu is unavailable (or failed to initialize); fall back to the GPU code path
torch_npu = None
npu_config = None
from opensora.utils.sample_utils import (
init_gpu_env, init_npu_env, prepare_pipeline, get_args,
run_model_and_save_samples, run_model_and_save_samples_npu
)
from opensora.sample.caption_refiner import OpenSoraCaptionRefiner
if __name__ == "__main__":
args = get_args()
dtype = torch.float16
if torch_npu is not None:
npu_config.print_msg(args)
npu_config.conv_dtype = dtype
init_npu_env(args)
else:
args = init_gpu_env(args)
device = torch.cuda.current_device()
if args.num_frames != 1 and args.enhance_video is not None:
from opensora.sample.VEnhancer.enhance_a_video import VEnhancer
enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=device)
else:
enhance_video_model = None
pipeline = prepare_pipeline(args, dtype, device)
if args.caption_refiner is not None:
caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device)
else:
caption_refiner_model = None
if npu_config is not None and npu_config.on_npu and npu_config.profiling:
run_model_and_save_samples_npu(args, pipeline, caption_refiner_model, enhance_video_model)
else:
run_model_and_save_samples(args, pipeline, caption_refiner_model, enhance_video_model)
================================================
FILE: opensora/serve/gradio_utils.py
================================================
import random
import imageio
import uuid
import torch
import numpy as np
POS_PROMPT = """
high quality, high aesthetic, {}
"""
NEG_PROMPT = """
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
"""
NUM_IMAGES_PER_PROMPT = 1
MAX_SEED = np.iinfo(np.int32).max
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
LOGO = """