Repository: PKU-YuanGroup/Open-Sora-Plan Branch: main Commit: f7fa604f4e3a Files: 171 Total size: 1.3 MB Directory structure: gitextract_la9ona01/ ├── .github/ │ └── workflows/ │ └── docker_build.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs/ │ ├── Contribution_Guidelines.md │ ├── Prompt_Refiner.md │ ├── Report-v1.0.0-cn.md │ ├── Report-v1.0.0.md │ ├── Report-v1.1.0.md │ ├── Report-v1.2.0.md │ ├── Report-v1.3.0.md │ ├── Report-v1.5.0.md │ ├── Report-v1.5.0_cn.md │ └── VAE.md ├── examples/ │ ├── cond_pix_path.txt │ ├── cond_prompt.txt │ ├── rec_image.py │ ├── rec_video.py │ └── sora.txt ├── opensora/ │ ├── __init__.py │ ├── acceleration/ │ │ ├── __init__.py │ │ ├── communications.py │ │ └── parallel_states.py │ ├── adaptor/ │ │ ├── __init__.py │ │ ├── bf16_optimizer.py │ │ ├── engine.py │ │ ├── modules.py │ │ ├── stage_1_and_2.py │ │ ├── utils.py │ │ └── zp_manager.py │ ├── dataset/ │ │ ├── __init__.py │ │ ├── inpaint_dataset.py │ │ ├── t2v_datasets.py │ │ ├── transform.py │ │ └── virtual_disk.py │ ├── models/ │ │ ├── __init__.py │ │ ├── causalvideovae/ │ │ │ ├── __init__.py │ │ │ ├── dataset/ │ │ │ │ ├── __init__.py │ │ │ │ ├── ddp_sampler.py │ │ │ │ ├── transform.py │ │ │ │ └── video_dataset.py │ │ │ ├── eval/ │ │ │ │ ├── cal_fvd.py │ │ │ │ ├── cal_lpips.py │ │ │ │ ├── cal_psnr.py │ │ │ │ ├── cal_ssim.py │ │ │ │ ├── eval.py │ │ │ │ ├── fvd/ │ │ │ │ │ ├── styleganv/ │ │ │ │ │ │ └── fvd.py │ │ │ │ │ └── videogpt/ │ │ │ │ │ ├── fvd.py │ │ │ │ │ └── pytorch_i3d.py │ │ │ │ └── script/ │ │ │ │ ├── cal_clip_score.sh │ │ │ │ ├── cal_fvd.sh │ │ │ │ ├── cal_lpips.sh │ │ │ │ ├── cal_psnr.sh │ │ │ │ └── cal_ssim.sh │ │ │ ├── model/ │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_videobase.py │ │ │ │ ├── dataset_videobase.py │ │ │ │ ├── ema_model.py │ │ │ │ ├── losses/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── discriminator.py │ │ │ │ │ ├── lpips.py │ │ │ │ │ └── perceptual_loss.py │ │ │ │ ├── modeling_videobase.py │ │ │ │ ├── modules/ │ │ │ │ │ ├── __init__.py │ │ │ │ 
│ ├── attention.py │ │ │ │ │ ├── block.py │ │ │ │ │ ├── conv.py │ │ │ │ │ ├── normalize.py │ │ │ │ │ ├── ops.py │ │ │ │ │ ├── quant.py │ │ │ │ │ ├── resnet_block.py │ │ │ │ │ ├── updownsample.py │ │ │ │ │ └── wavelet.py │ │ │ │ ├── registry.py │ │ │ │ ├── trainer_videobase.py │ │ │ │ ├── utils/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── distrib_utils.py │ │ │ │ │ ├── module_utils.py │ │ │ │ │ ├── scheduler_utils.py │ │ │ │ │ ├── video_utils.py │ │ │ │ │ └── wavelet_utils.py │ │ │ │ └── vae/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_causalvae.py │ │ │ │ └── modeling_wfvae.py │ │ │ ├── sample/ │ │ │ │ └── rec_video_vae.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── dataset_utils.py │ │ │ ├── downloader.py │ │ │ └── video_utils.py │ │ ├── diffusion/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── opensora_v1_3/ │ │ │ ├── __init__.py │ │ │ ├── modeling_inpaint.py │ │ │ ├── modeling_opensora.py │ │ │ └── modules.py │ │ ├── frame_interpolation/ │ │ │ ├── cfgs/ │ │ │ │ └── AMT-G.yaml │ │ │ ├── interpolation.py │ │ │ ├── networks/ │ │ │ │ ├── AMT-G.py │ │ │ │ ├── __init__.py │ │ │ │ └── blocks/ │ │ │ │ ├── __init__.py │ │ │ │ ├── feat_enc.py │ │ │ │ ├── ifrnet.py │ │ │ │ ├── multi_flow.py │ │ │ │ └── raft.py │ │ │ ├── readme.md │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── build_utils.py │ │ │ ├── dist_utils.py │ │ │ ├── flow_utils.py │ │ │ └── utils.py │ │ ├── prompt_refiner/ │ │ │ ├── inference.py │ │ │ ├── merge.py │ │ │ └── train.py │ │ └── text_encoder/ │ │ ├── __init__.py │ │ ├── clip.py │ │ └── t5.py │ ├── npu_config.py │ ├── sample/ │ │ ├── caption_refiner.py │ │ ├── pipeline_inpaint.py │ │ ├── pipeline_opensora.py │ │ ├── rec_image.py │ │ ├── rec_video.py │ │ └── sample.py │ ├── serve/ │ │ ├── gradio_utils.py │ │ ├── gradio_web_server.py │ │ ├── gradio_web_server_i2v.py │ │ └── style.css │ ├── train/ │ │ ├── train_causalvae.py │ │ ├── train_inpaint.py │ │ └── train_t2v_diffusers.py │ └── utils/ │ ├── communications.py │ ├── dataset_utils.py │ ├── 
downloader.py │ ├── ema.py │ ├── ema_utils.py │ ├── freeinit_utils.py │ ├── lora_utils.py │ ├── mask_utils.py │ ├── parallel_states.py │ ├── sample_utils.py │ └── utils.py ├── pyproject.toml └── scripts/ ├── accelerate_configs/ │ ├── ddp_config.yaml │ ├── deepspeed_zero2_config.yaml │ ├── deepspeed_zero2_offload_config.yaml │ ├── deepspeed_zero3_config.yaml │ ├── deepspeed_zero3_offload_config.yaml │ ├── default_config.yaml │ ├── hostfile │ ├── zero2.json │ ├── zero2_npu.json │ ├── zero2_offload.json │ ├── zero3.json │ └── zero3_offload.json ├── causalvae/ │ ├── eval.sh │ ├── prepare_eval.sh │ ├── rec_image.sh │ ├── rec_video.sh │ ├── train.sh │ └── wfvae_4dim.json ├── slurm/ │ └── placeholder ├── text_condition/ │ ├── gpu/ │ │ ├── sample_inpaint_v1_3.sh │ │ ├── sample_t2v_v1_3.sh │ │ ├── train_inpaint_v1_3.sh │ │ └── train_t2v_v1_3.sh │ └── npu/ │ ├── sample_inpaint_v1_3.sh │ ├── sample_t2v_v1_3.sh │ ├── train_inpaint_v1_3.sh │ └── train_t2v_v1_3.sh ├── train_configs/ │ └── mask_config.yaml └── train_data/ └── merge_data.txt ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/docker_build.yml ================================================ name: docker-build on: workflow_dispatch: push: branches: - "main" paths: - "docker/Dockerfile" jobs: build-Open-Sora: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build and push Open-Sora image uses: docker/build-push-action@v5 with: context: . 
file: ./docker/Dockerfile push: true platforms: linux/amd64, linux/arm64, linux/s390x, linux/ppc64le tags: ${{ secrets.DOCKERHUB_USERNAME }}/open-sora ================================================ FILE: .gitignore ================================================ ucf101_stride4x4x4 __pycache__ *.mp4 .ipynb_checkpoints *.pth UCF-101/ results/ build/ opensora.egg-info/ wandb/ .idea *.ipynb *.jpg *.mp3 *.safetensors *.mp4 *.png *.gif *.pth *.pt cache_dir/ wandb/ test* sample_video*/ 512* 720* 1024* *debug* private* .deepspeed_env 256* sample_image*/ taming* *test* sft* flash* 65x256* alpha_vae *node* cache/ Open-Sora-Plan_models/ sample_image*cfg* *tmp* *pymp* check.py bucket.py whileinf.py validation_dir/ runs/ samples/ inpaint*/ bs32x8x1* *tmp* *pymp* check.py bucket.py whileinf.py bs4x8x16_* *.zip *validation/ bs1x8x32* bs16x8x1* bs8x8x2* bs8x8x1* bs8x8x8* bs1x8x16* checklora.py dim4todim8.py *vae8_any*320x320* samples/ runs/ *validation/ training_log*txt filter_motion* json2*.py motionfun* res_dist* filter_json_aes_m* stage2*.json kernel_meta ge_check_op.json WFVAE_DISTILL_FORMAL read_video* bs32x8x2* filter_json_aes_m* json2json* makenpu_json* *make_small_json* *schedule_noise* test* gpu_profiling* gyy_dense* torchelasti* *VEnhancer* *spdemo* i2v.txt *run_i2v* *curope* any* *nomotion* log* *svg *k8s* *rf* *lzj* final* opensora/train/*debug.py ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) Rabbitpre Intelligence Ltd Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and 
this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================

Open-Sora Plan

This project aims to create a simple and scalable repo, to reproduce [Sora](https://openai.com/sora) (OpenAI, but we prefer to call it "ClosedAI" ). 本项目希望通过开源社区的力量复现Sora,由北大-兔展AIGC联合实验室共同发起,来自兔展、华为、鹏城实验室和开源社区伙伴均有深度贡献力量。 当前V1.5版本**完全基于华为昇腾训练(昇腾纯血版)**,欢迎Pull Request和使用! 我们正在快速迭代新版本,欢迎更多合作者或算法工程师加入,[算法工程师招聘-兔展智能.pdf](https://github.com/user-attachments/files/19107972/-.pdf)
[![arXiv](https://img.shields.io/badge/Arxiv-Open--Sora%20Plan-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2412.00131) [![arXiv](https://img.shields.io/badge/Arxiv-Helios-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2603.04379) [![arXiv](https://img.shields.io/badge/Arxiv-WF--VAE-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2411.17459) [![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/LICENSE)
[![slack badge](https://img.shields.io/badge/Discord-join-blueviolet?logo=discord&)](https://discord.gg/DFZg5678) [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues/53#issuecomment-1987226516) [![Twitter](https://img.shields.io/badge/-Twitter@LinBin46984-black?logo=twitter&logoColor=1D9BF0)](https://x.com/LinBin46984/status/1795018003345510687) [![Modelers](https://img.shields.io/badge/%E9%AD%94%E4%B9%90-%E6%A8%A1%E5%9E%8B%E4%BD%93%E9%AA%8C-blue)](https://modelers.cn/spaces/MindSpore-Lab/Open_Sora_Plan)
[![GitHub repo stars](https://img.shields.io/github/stars/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/stargazers)  [![GitHub repo forks](https://img.shields.io/github/forks/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Forks)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/network)  [![GitHub repo watchers](https://img.shields.io/github/watchers/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Watchers)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/watchers)  [![GitHub repo size](https://img.shields.io/github/repo-size/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Repo%20Size)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/archive/refs/heads/main.zip)
[![GitHub repo contributors](https://img.shields.io/github/contributors-anon/PKU-YuanGroup/Open-Sora-Plan?style=flat&label=Contributors)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/graphs/contributors) [![GitHub Commit](https://img.shields.io/github/commit-activity/m/PKU-YuanGroup/Open-Sora-Plan?label=Commit)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/commits/main/) [![Pr](https://img.shields.io/github/issues-pr-closed-raw/PKU-YuanGroup/Open-Sora-Plan.svg?label=Merged+PRs&color=green)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) [![GitHub issues](https://img.shields.io/github/issues/PKU-YuanGroup/Open-Sora-Plan?color=critical&label=Issues)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues?q=is%3Aopen+is%3Aissue) [![GitHub closed issues](https://img.shields.io/github/issues-closed/PKU-YuanGroup/Open-Sora-Plan?color=success&label=Issues)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues?q=is%3Aissue+is%3Aclosed)
PKU-YuanGroup%2FOpen-Sora-Plan | Trendshift
If you like our project, please give us a star ⭐ on GitHub for latest update.
# 📣 News * **[2026.03.08]** 👋👋👋 We introduce [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model that achieves minute-scale, high-quality video synthesis at **19.5 FPS on a single H100** GPU — without relying on conventional long video anti-drifting strategies or standard video acceleration techniques. Welcome to check [Technical Report](https://huggingface.co/papers/2603.04379)! * **[2025.06.05]** 🔥🔥🔥 We release version 1.5.0, our most powerful model! By introducing a **higher-compression WFVAE** and an improved sparse DiT architecture, **SUV**, we achieve performance **comparable to HunyuanVideo (Open-Source)** using an 8B-scale model and 40 million video samples. Version 1.5.0 is **fully trained and inferred on Ascend 910-series accelerators**; Please check the [mindspeed_mmdit](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/mindspeed_mmdit) branch for our new code and [Report-v1.5.0.md](docs/Report-v1.5.0.md) for our report. The GPU version is coming soon. * **[2024.12.03]** ⚡️ We released our [arxiv paper](https://arxiv.org/abs/2412.00131) and WF-VAE [paper](https://arxiv.org/abs/2411.17459) for v1.3. The next more powerful version is coming soon. * **[2024.10.16]** 🎉 We released version 1.3.0, featuring: **WFVAE**, **prompt refiner**, **data filtering strategy**, **sparse attention**, and **bucket training strategy**. We also support 93x480p within **24G VRAM**. More details can be found at our latest [report](docs/Report-v1.3.0.md). * **[2024.08.13]** 🎉 We are launching Open-Sora Plan v1.2.0 **I2V** model, which is based on Open-Sora Plan v1.2.0. The current version supports image-to-video generation and transition generation (the starting and ending frames conditions for video generation). Check out the Image-to-Video section in this [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.2.0.md#training-image-to-video-diffusion-model). * **[2024.07.24]** 🔥🔥🔥 v1.2.0 is here! 
Utilizing a 3D full attention architecture instead of 2+1D. We released a true 3D video diffusion model trained on 4s 720p. Check out our latest [report](docs/Report-v1.2.0.md). * **[2024.05.27]** 🎉 We are launching Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out our latest [report](docs/Report-v1.1.0.md). Thanks to [ShareGPT4Video's](https://sharegpt4video.github.io/) capability to annotate long videos. * **[2024.04.09]** 🤝 Excited to share our latest exploration on metamorphic time-lapse video generation: [MagicTime](https://github.com/PKU-YuanGroup/MagicTime), which learns real-world physics knowledge from time-lapse videos. * **[2024.04.07]** 🎉🎉🎉 Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.0.0.md). Thanks to HUAWEI NPU for supporting us. * **[2024.03.27]** 🚀🚀🚀 We release the report of [VideoCausalVAE](docs/CausalVideoVAE.md), which supports both images and videos. We present our reconstructed video in this demonstration as follows. The text-to-video model is on the way. * **[2024.03.01]** 🤗 We launched a plan to reproduce Sora, called Open-Sora Plan! Welcome to **watch** 👀 this repository for the latest updates. # 😍 Gallery Text-to-Video Generation of Open-Sora Plan v1.5.0. ### Youtube: [![Demo Video of Open-Sora Plan V1.5.0](https://github.com/user-attachments/assets/130bbba2-3ded-4092-92ef-b65b673cb1a6)](https://youtu.be/IiWTdx2EHCY) ### Bilibili: [![Demo Video of Open-Sora Plan V1.5.0](https://github.com/user-attachments/assets/130bbba2-3ded-4092-92ef-b65b673cb1a6)](https://www.bilibili.com/video/BV1X77tzxE3b/) # 😮 Highlights Open-Sora Plan shows excellent performance in video generation. ### 🔥 WFVAE with higher performance and compression - With an 8×8×8 downsampling rate, but achieves higher PSNR than the VAE used in Wan2.1. 
Lowers the training cost for the DiT built upon it. ### 🚀 More powerful sparse dit - The more powerful sparse attention architecture, SUV, achieves performance close to dense DiT while providing over a 35% speedup.

# 🐳 Resource | Version | Architecture | Diffusion Model | CausalVideoVAE | Data | Prompt Refiner | |:---|:---|:---|:---|:---|:---| | v1.5.0 | SUV (Skiparse 3D) | [121x576x1024](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.5.0/blob/main/MindSpeed/model_ema.pt)[5] | [Anysize_8x8x8_32dim](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.5.0/blob/main/MindSpeed/wfvae_888_dim32.ckpt) | - | - | | v1.3.0 [4] | Skiparse 3D | [Anysize in 93x640x640](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640)[3], [Anysize in 93x640x640_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640_i2v)[3] | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae)| [prompt_refiner](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) | [checkpoint](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner)| | | v1.2.0 | Dense 3D | [93x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p), [29x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p)[1], [93x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p)[1,2], [29x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p), [1x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p), [93x480p_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p_i2v) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/vae)| [Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) | - | | v1.1.0 | 2+1D | [221x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512), [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) |[Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/vae) |[Data 
and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0)| - | | v1.0.0 | 2+1D | [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512), [65x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256), [17x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/vae) | [Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)| - | > [1] Please note that the weights for v1.2.0 29×720p and 93×480p were trained on Panda70M and have not undergone final high-quality data fine-tuning, so they may produce watermarks. > [2] We fine-tuned 3.5k steps from 93×720p to get 93×480p for community research use. > [3] The model is trained arbitrarily on stride=32. So keep the resolution of the inference a multiple of 32. Frames need to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). > [4] Model weights are also available at [OpenMind](https://modelers.cn/models/linbin/Open-Sora-Plan-v1.3.0) and [WiseModel](https://wisemodel.cn/models/PKU-YUAN/Open-Sora-Plan-v1.3.0). > [5] The current model weights are only compatible with the NPU + MindSpeed-MM framework. Model weights are also available at [modelers](https://modelers.cn/models/PKU-YUAN-Group/Open-Sora-Plan-v1.5.0/tree/main/MindSpeed). > [!Warning] > >

> > 🚨 For version 1.2.0, we no longer support 2+1D models. > >
# ⚙️ How to start ### GPU coming soon... ### NPU Please check out the **[mindspeed_mmdit](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/mindspeed_mmdit)** branch and follow the README.md for configuration. # 📖 Technical report Please check [Report-v1.5.0.md](docs/Report-v1.5.0.md). # 💡 How to Contribute We greatly appreciate your contributions to the Open-Sora Plan open-source community and helping us make it even better than it is now! For more details, please refer to the [Contribution Guidelines](docs/Contribution_Guidelines.md) # 👍 Acknowledgement and Related Work * [Allegro](https://github.com/rhymes-ai/Allegro): Allegro is a powerful text-to-video model that generates high-quality videos up to 6 seconds at 15 FPS and 720p resolution from simple text input based on our Open-Sora Plan. The significance of open-source is becoming increasingly tangible. * [Latte](https://github.com/Vchitect/Latte): It is a wonderful 2+1D video generation model. * [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis. * [ShareGPT4Video](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4Video): Improving Video Understanding and Generation with Better Captions. * [VideoGPT](https://github.com/wilson1yan/VideoGPT): Video Generation using VQ-VAE and Transformers. * [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers. * [FiT](https://github.com/whlzy/FiT): Flexible Vision Transformer for Diffusion Model. * [Positional Interpolation](https://arxiv.org/abs/2306.15595): Extending Context Window of Large Language Models via Positional Interpolation. # 🔒 License * See [LICENSE](LICENSE) for details. 
## ✨ Star History [![Star History](https://api.star-history.com/svg?repos=PKU-YuanGroup/Open-Sora-Plan)](https://star-history.com/#PKU-YuanGroup/Open-Sora-Plan&Date) # ✏️ Citing ```bibtex @article{lin2024open, title={Open-Sora Plan: Open-Source Large Video Generation Model}, author={Lin, Bin and Ge, Yunyang and Cheng, Xinhua and Li, Zongjian and Zhu, Bin and Wang, Shaodong and He, Xianyi and Ye, Yang and Yuan, Shenghai and Chen, Liuhan and others}, journal={arXiv preprint arXiv:2412.00131}, year={2024} } ``` ```bibtex @article{helios, title={Helios: Real Real-Time Long Video Generation Model}, author={Yuan, Shenghai and Yin, Yuanyang and Li, Zongjian and Huang, Xinwei and Yang, Xiao and Yuan, Li}, journal={arXiv preprint arXiv:2603.04379}, year={2026} } ``` ```bibtex @article{li2024wf, title={WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model}, author={Li, Zongjian and Lin, Bin and Ye, Yang and Chen, Liuhan and Cheng, Xinhua and Yuan, Shenghai and Yuan, Li}, journal={arXiv preprint arXiv:2411.17459}, year={2024} } ``` # 🤝 Community contributors ================================================ FILE: docs/Contribution_Guidelines.md ================================================ # Contributing to the Open-Sora Plan Community The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights! ## Submitting a Pull Request (PR) As a contributor, before submitting your request, kindly follow these guidelines: 1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work. 2. 
[Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [open-sora plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and download your forked repository to your local machine. ```bash git clone [your-forked-repository-url] ``` 3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates: ```bash git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan ``` 4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository. ``` # Pull the latest code from the upstream branch git fetch upstream # Switch to the main branch git checkout main # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream git merge upstream/main # Additionally, sync the local main branch to the remote branch of your forked repository git push origin main ``` > Note: Sync the code from the main repository before each submission. 5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful. ```bash git checkout -b my-docs-branch main ``` 6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format). ```bash git commit -m "[docs]: xxxx" ``` 7. Push your changes to your GitHub repository. ```bash git push origin my-docs-branch ``` 8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page. ## Commit Message Format Commit messages must include both `<type>` and `<summary>` sections. ```bash [<type>]: <summary> │ │ │ └─⫸ Briefly describe your changes, without ending with a period. │ └─⫸ Commit Type: |docs|feat|fix|refactor| ``` ### Type * **docs**: Modify or add documents. * **feat**: Introduce a new feature. * **fix**: Fix a bug. * **refactor**: Restructure code, excluding new features or bug fixes. ### Summary Describe modifications in English, without ending with a period. 
> e.g., git commit -m "[docs]: add a contributing.md file" This guideline is borrowed from [minisora](https://github.com/mini-sora/minisora). We sincerely thank the MiniSora authors for their awesome templates. ================================================ FILE: docs/Prompt_Refiner.md ================================================ ## Data We have open-sourced our dataset of 32,555 pairs, which includes Chinese data. The dataset is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). The details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md#prompt-refiner). In fact, it is a JSON file with the following structure. ``` [ { "instruction": "Refine the sentence: \"A newly married couple sharing a piece of there wedding cake.\" to contain subject description, action, scene description. (Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.", "input": "", "output": "The newlywed couple, dressed in elegant attire..." }, ... ] ``` ## Train `--data_path` is the path to the prepared JSON file. `--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and some weight files. `--lora_out_path` is the path where the LoRA model will be saved. ``` cd opensora/models/prompt_refiner CUDA_VISIBLE_DEVICES=0 python train.py \ --data_path path/to/data.json \ --model_path path/to/llama_model \ --lora_out_path path/to/save/lora_model ``` ## Merge `--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and some weight files. `--lora_in_path` is the directory containing the pre-trained LoRA model. `--lora_out_path` is the path for the merged model. 
``` cd opensora/models/prompt_refiner CUDA_VISIBLE_DEVICES=0 python merge.py \ --base_path path/to/llama_model \ --lora_in_path path/to/save/lora_model \ --lora_out_path path/to/save/merge_model ``` ## Inference `--model_path` is the directory containing the weights (LLaMA 3.1 or merged LoRA weights), including `config.json` and some weight files. `--prompt` is the text you want to input, which will be refined. ``` cd opensora/models/prompt_refiner CUDA_VISIBLE_DEVICES=0 python inference.py \ --mode_path path/to/model \ --prompt "the text you want to refine" ``` ================================================ FILE: docs/Report-v1.0.0-cn.md ================================================ # 技术报告 v1.0.0 在2024年3月,我们推出了Open-Sora-Plan,一个旨在复现OpenAI [Sora](https://openai.com/sora)的开源计划。它作为一个基础的开源框架,能够训练视频生成模型包括无条件视频生成,类别引导视频生成,文生视频。 **今天,我们兴奋地展示Open-Sora-Plan v1.0.0,极大地改进视频生成质量、文本控制能力。** 相比于之前的视频生成模型,Open-Sora-Plan v1.0.0 有以下的改进: 1. **CausalVideoVAE高效的训练与推理**。 我们以4×8×8的压缩率对视频进行时间和空间的压缩。 2. **图片视频联合训练提升视觉质量**。 CausalVideoVAE 将首帧看作图片,天然支持同时编码图片和视频。这允许扩散模型提取更多时空细节来改善质量。 ### Open-Source Release 我们开源了Open-Sora-Plan去促进视频生成社区的进一步发展。公开代码、数据、模型。 - 在线演示:Hugging Face [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0), [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) 和 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), 感谢[@camenduru](https://github.com/camenduru)大力支持我们的工作!🤝 - 代码:所有训练脚本和采样代码。 - 模型:包括扩散模型和CausalVideoVAE [这里](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0)。 - 数据:所有原视频和对应描述 [这里](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)。 ## 效果 Open-Sora-Plan v1.0.0支持图片视频联合训练。我们在此展示视频和图片的重建以及生成: 720×1280**视频重建**。 因为github的限制,原视频放在: 
[1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8). https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68 1536×1024**图片重建** 65×1024×1024**文生视频** https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011 65×512×512**文生视频** https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e 512×512**文生视频** ![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6) ## 详细技术报告 ### CausalVideoVAE #### 模型结构 ![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8) 因果VAE架构继承了[Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main)。 为了保证图片VAE的预训练权重可以无缝运用到视频VAE中,模型结构采取如下设计: 1. **CausalConv3D**: 将Conv2D 转变成CausalConv3D可以实现图片和视频的联合训练. CausalConv3D 对第一帧进行特殊处理,因为它无法访问后续帧。对于更多细节,请参考https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145 2. 
**初始化**:将Conv2D扩展到Conv3D常用的[方法](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5)有两种:平均初始化和中心初始化。 但我们采用了特定的初始化方法(尾部初始化)。 这种初始化方法确保模型无需任何训练就能够直接重建图像,甚至视频。 #### 训练细节 image 我们展示了 17×256×256 下两种不同初始化方法的损失曲线。黄色曲线代表使用尾部初始化的损失,而蓝色曲线对应中心初始化的损失。 如图所示,尾部初始化在损失曲线上表现出更好的性能。 此外,我们发现中心初始化会导致错误累积,导致长时间内崩溃。 #### 推理技巧 尽管训练Diffusion中VAE始终是冻住的,我们仍然无法负担CausalVideoVAE的花销。在我们的实验中, 80G的显存只能够在半精度下推理一个256×512×512或32×1024×1024的视频 ,这限制了我们扩展到更长更高清的视频。因此我们采用tile convolution,能够以几乎恒定的内存推理任意时长或任意分辨率的视频。 ### 数据构建 我们定义高质量的视频数据集包括两个核心法则:(1) 没有与内容无关的水印。(2) 高质量的文本注释。 **对于法则1**,我们从开源网站(CC0协议)爬取了大约40k videos:1234个来自[mixkit](https://mixkit.co/),7408个来自[pexels](https://www.pexels.com/),31616个来自[pixabay](https://pixabay.com/)。我们根据[Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md)提供的场景变换剪切script将这些视频切成大约434k video clips。事实上,根据我们的剪切结果,从这些网站上爬取的99%的视频都是单一的场景。另外,我们发现爬取的数据中超过60%为风景相关视频。更多细节可以在[这](https://github.com/PKU-YuanGroup/Open-Sora-Dataset)找到。 **对于法则2**,很难有大量的高质量的文本注释能够从网上直接爬取。因此我们用成熟的图片标注模型来获取高质量的稠密描述。我们对2个多模态大模型进行消融实验:[ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) 和 [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA)。前者是专门用来制作文本注释的模型,而后者是一个通用的多模态大模型。经过我们的消融实验,他们在caption的表现差不多。然而他们的推理速度在A800上差距很大:40s/it of batch size of 12 for ShareGPT4V-Captioner-7B,15s/it of batch size of 1 for LLaVA-1.6-34B。我们开源所有的[文本注释和原视频](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)。 | 模型名字 | 平均长度 | 最大值 | 标准差 | |---|---|---|---| | ShareGPT4V-Captioner-7B | 170.0827524529121 | 467 | 53.689967539537776 | | LLaVA-1.6-34B | 141.75851073472666 | 472 | 48.52492072346965 | ### 训练扩散模型 与之前的工作类似,我们采用多阶段的级联的训练方法,总共消耗了2048个A800 GPU 小时。我们发现联合图片训练能够显著加速模型的收敛并且增强视觉观感,这与[Latte](https://github.com/Vchitect/Latte)一致。以下是我们的训练花销。 | 名字 | Stage 1 | Stage 2 | Stage 3 | Stage 4 | |---|---|---|---|---| | 训练视频尺寸 | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 | | 计算资源 (#A800 GPU x 
#小时) | 32 × 40 | 32 × 18 | 32 × 6 | 训练中 | | 权重 | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | 训练中 | | 日志 | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | 训练中 | | 训练数据 | ~40k videos | ~40k videos | ~40k videos | ~40k videos | ## 下版本预览 ### CausalVideoVAE 目前我们发布的CausalVideoVAE v1.0.0版本存在2个主要的缺陷:**运动模糊**以及**网格效应**。我们对CasualVideoVAE做了一系列的改进使它推理成本更低且性能更强大,我们暂时叫它为预览版本,将在下个版本发布。 **1分钟720×1280视频重建**。 受限于GitHub,我们将原视频放在这:[原视频](https://streamable.com/u4onbb),[重建视频](https://streamable.com/qt8ncc)。 https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b 我们从kinetic 400的验证集中随机选取100个样本进行评估,结果表如下所示: | | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ | |---|---|---|---|---| | v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 | | Preview | 0.877 | 0.064 | 29.695 | 0.070 | #### 运动模糊 | **v1.0.0** | **预览版本** | | --- | --- | | ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d) | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c) | #### 网格效应 | **v1.0.0** | **预览版本** | | --- | --- | | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658) | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7) | ### 数据构建 **数据源**:正如上文提到,我们的数据集中超过60%为风景视频。这意味着我们的开域视频生成能力有限。然而当前的大规模开源数据集大多从YouTube爬取,尽管视频的数量多,但我们担忧视频本身的质量是否达标。因此,我们将继续收集高质量的数据集,同时也欢迎开源社区的推荐。 **Caption生成流程**:当我们训练时长增加时,我们不得不考虑更有效的视频caption生成方法,而不是多模态图片大模型。我们正在开发一个新的视频注释生成管线,它能够很好的支持长视频,敬请期待。 ### 训练扩散模型 
尽管目前v1.0.0展现了可喜的结果,但我们仍然离Sora有一段距离。在接下来的工作中,我们主要围绕这三个方面: 1. **动态分辨率与时长的训练**: 我们的目标是开发出能够以不同分辨率和持续时间训练模型的技术,使训练过程更加灵活、适应性更强。 2. **更长的视频生成**: 我们将探索扩展模型生成能力的方法,使其能够制作更长的视频,超越目前的限制。 3. **更多条件控制**: 我们力求增强模型的条件控制能力,为用户提供更多的选项和对生成视频的控制能力。 另外,通过仔细观察生成的视频,我们发现存在一些不符合常理的斑点或异常的流动。如上文所述,这是由于CausalVideoVAE的性能不足导致的。在未来的实验中,我们将使用更强的VAE,重新训练一个扩散模型。 ================================================ FILE: docs/Report-v1.0.0.md ================================================ # Report v1.0.0 In March 2024, we launched a plan called Open-Sora-Plan, which aims to reproduce the OpenAI [Sora](https://openai.com/sora) through an open-source framework. As a foundational open-source framework, it enables training of video generation models, including Unconditioned Video Generation, Class Video Generation, and Text-to-Video Generation. **Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities.** Compared with previous video generation models, Open-Sora-Plan v1.0.0 has several improvements: 1. **Efficient training and inference with CausalVideoVAE**. We apply a spatial-temporal compression to the videos by 4×8×8. 2. **Joint image-video training for better quality**. Our CausalVideoVAE considers the first frame as an image, allowing for the simultaneous encoding of both images and videos in a natural manner. This allows the diffusion model to grasp more spatial-visual details to improve visual quality. ### Open-Source Release We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model are made publicly available. - Demo: Hugging Face demo [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0). 
🤝 Enjoying the [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), created by [@camenduru](https://github.com/camenduru), who generously supports our research! - Code: All training scripts and sample scripts. - Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0). - Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0). ## Gallery Open-Sora-Plan v1.0.0 supports joint training of images and videos. Here, we present the capabilities of Video/Image Reconstruction and Generation: ### CausalVideoVAE Reconstruction **Video Reconstruction** with 720×1280. Since github can't upload large video, we put it here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8). https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68 **Image Reconstruction** in 1536×1024. 
**Text-to-Video Generation** with 65×1024×1024 https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011 **Text-to-Video Generation** with 65×512×512 https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e **Text-to-Image Generation** with 512×512 ![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6) ## Detailed Technical Report ### CausalVideoVAE #### Model Structure ![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8) The CausalVideoVAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the Image VAE can be seamlessly applied to the Video VAE, the model structure has been designed as follows: 1. **CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training of image and video data. CausalConv3D applies a special treatment to the first frame, as it does not have access to subsequent frames. For more specific details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145 2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. But we employ a specific initialization method (tail initialization). This initialization method ensures that without any training, the model is capable of directly reconstructing images, and even videos. #### Training Details image We present the loss curves for two distinct initialization methods under 17×256×256. The yellow curve represents the loss using tail init, while the blue curve corresponds to the loss from center initialization. As shown in the graph, tail initialization demonstrates better performance on the loss curve. 
Additionally, we found that center initialization leads to error accumulation, causing collapse over extended durations. #### Inference Tricks Despite the VAE in Diffusion training being frozen, we still find it challenging to afford the cost of the CausalVideoVAE. In our case, with 80GB of GPU memory, we can only infer a video of either 256×512×512 or 32×1024×1024 resolution using half-precision, which limits our ability to scale up to longer and higher-resolution videos. Therefore, we adopt tile convolution, which allows us to infer videos of arbitrary duration or resolution with nearly constant memory usage. ### Data Construction We define a high-quality video dataset based on two core principles: (1) No content-unrelated watermarks. (2) High-quality and dense captions. **For principle 1**, we crawled approximately 40,000 videos from open-source websites under the CC0 license. Specifically, we obtained 1,234 videos from [mixkit](https://mixkit.co/), 7,408 videos from [pexels](https://www.pexels.com/), and 31,616 videos from [pixabay](https://pixabay.com/). These videos adhere to the principle of having no content-unrelated watermarks. According to the scene transformation and clipping script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md), we have divided these videos into approximately 434,000 video clips. In fact, based on our clipping results, 99% of the videos obtained from these online sources are found to contain single scenes. Additionally, we have observed that over 60% of the crawled data comprises landscape videos. More details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Dataset). **For principle 2**, it is challenging to directly crawl a large quantity of high-quality dense captions from the internet. Therefore, we utilize a mature Image-captioner model to obtain high-quality dense captions. 
We conducted ablation experiments on two multimodal large models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is specifically designed for caption generation, while the latter is a general-purpose multimodal large model. After conducting our ablation experiments, we found that they are comparable in performance. However, there is a significant difference in their inference speed on the A800 GPU: 40s/it of batch size of 12 for ShareGPT4V-Captioner-7B, 15s/it of batch size of 1 for LLaVA-1.6-34B. We open-source all annotations [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0). We show some statistics here, and we set the maximum length of the model to 300, which covers almost 99% of the samples. | Name | Avg length | Max | Std | |---|---|---|---| | ShareGPT4V-Captioner-7B | 170.0827524529121 | 467 | 53.689967539537776 | | LLaVA-1.6-34B | 141.75851073472666 | 472 | 48.52492072346965 | ### Training Diffusion Model Similar to previous work, we employ a multi-stage cascaded training approach, which consumes a total of 2,528 A800 GPU hours. We found that joint training with images significantly accelerates model convergence and enhances visual perception, aligning with the findings of [Latte](https://github.com/Vchitect/Latte). 
Below is our training card: | Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 | |---|---|---|---|---| | Training Video Size | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 | | Compute (#A800 GPU x #Hours) | 32 × 40 | 32 × 22 | 32 × 17 | Under training | | Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | Under training | | Log | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | Under training | | Training Data | ~40k videos | ~40k videos | ~40k videos | ~40k videos | ## Next Release Preview ### CausalVideoVAE Currently, the released version of CausalVideoVAE (v1.0.0) has two main drawbacks: **motion blurring** and **gridding effect**. We have made a series of improvements to CausalVideoVAE to reduce its inference cost and enhance its performance. We are currently referring to this enhanced version as the "preview version," which will be released in the next update. Preview reconstruction is as follows: **1 min Video Reconstruction with 720×1280**. Since GitHub cannot host large videos, we put them here: [origin video](https://streamable.com/u4onbb), [reconstruction video](https://streamable.com/qt8ncc). 
https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b We randomly selected 100 samples from the validation set of Kinetics-400 for evaluation, and the results are presented in the following table: | | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ | |---|---|---|---|---| | v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 | | Preview | 0.877 | 0.064 | 29.695 | 0.070 | #### Motion Blurring | **v1.0.0** | **Preview** | | --- | --- | | ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d) | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c) | #### Gridding effect | **v1.0.0** | **Preview** | | --- | --- | | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658) | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7) | ### Data Construction **Data source**. As mentioned earlier, over 60% of our dataset consists of landscape videos. This implies that our ability to generate videos in other domains is limited. However, most of the current large-scale open-source datasets are primarily obtained through web scraping from platforms like YouTube. While these datasets provide a vast quantity of videos, we have concerns about the quality of the videos themselves. Therefore, we will continue to collect high-quality datasets and also welcome recommendations from the open-source community. We are launching an Open-Sora-Dataset project, check out the details at [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset) **Caption Generation Pipeline**. As the video duration increases, we need to consider more efficient methods for video caption generation instead of relying solely on large multimodal image models. 
We are currently developing a new video caption generation pipeline that provides robust support for long videos. We are excited to share more details with you in the near future. Stay tuned! ### Training Diffusion Model Although v1.0.0 has shown promising results, we acknowledge that we still have a ways to go to reach the level of Sora. In our upcoming work, we will primarily focus on three aspects: 1. **Training support for dynamic resolution and duration**: We aim to develop techniques that enable training models with varying resolutions and durations, allowing for more flexible and adaptable training processes. 2. **Support for longer video generation**: We will explore methods to extend the generation capabilities of our models, enabling them to produce longer videos beyond the current limitations. 3. **Enhanced conditional control**: We seek to enhance the conditional control capabilities of our models, providing users with more options and control over the generated videos. Furthermore, through careful observation of the generated videos, we have noticed the presence of some physically implausible speckles or abnormal flow. This can be attributed to the limited performance of CausalVideoVAE, as mentioned earlier. In future experiments, we plan to retrain a diffusion model using a more powerful version of CausalVideoVAE to address these issues. ================================================ FILE: docs/Report-v1.1.0.md ================================================ # Report v1.1.0 In April 2024, we launched Open-Sora-Plan v1.0.0, featuring a simple and efficient design along with remarkable performance in text-to-video generation. It has already been adopted as a foundational model in numerous research projects, including its data and model. **Today, we are excited to present Open-Sora-Plan v1.1.0, which significantly improves video generation quality and duration.** Compared to the previous version, the improvements in Open-Sora-Plan v1.1.0 include: 1. 
**Better compressed visual representations**. We optimized the CausalVideoVAE architecture, which now has stronger performance and higher inference efficiency. 2. **Generate higher quality, longer videos**. We used higher quality visual data and captions by [ShareGPT4Video](https://sharegpt4video.github.io/), enabling the model to better understand the workings of the world. Along with performance improvements, Open-Sora-Plan v1.1.0 maintains the minimalist design and data efficiency of v1.0.0. Remarkably, we found that v1.1.0 exhibits similar performance to the Sora base model, indicating that our version's evolution aligns with the scaling law demonstrated by Sora. ### Open-Source Release We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model will be made publicly available. - Demo: Hugging Face demo [here](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0). - Code: All training scripts and sample scripts. - Model: Both Diffusion Model and CasualVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0). - Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0). ## Gallery ### 221×512×512 Text-to-Video Generation | 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) | | --- | --- | --- | --- | | | | | | | This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage ... | a cat wearing sunglasses and working as a lifeguard at pool. | Photorealistic closeup video of two pirate ships battling each other as they sail ... | A movie trailer featuring the adventures ofthe 30 year old spacemanwearing a redwool ... | | | | | | | A snowy forest landscape with a dirt road running through it. The road is flanked by ... | Drone shot along the Hawaii jungle coastline, sunny day. Kayaks in the water. | Alpacas wearing knit wool sweaters, graffiti background, sunglasses. 
| The camera rotates around a large stack of vintage televisions all showing different ... | | | | | | | A drone camera circles around a beautiful historic church built on a rocky outcropping ... | Aerial view of Santorini during the blue hour, showcasing the stunning architecture ... | A robot dog explores the surface of Mars, kicking up red dust as it investigates ... | An aerial shot of a lighthouse standing tall on a rocky cliff, its beacon cutting ... | | | | | | | 3D animation of a small, round, fluffy creature with big, expressive eyes explores ... | A corgi vlogging itself in tropical Maui. | A single drop of liquid metal falls from a floating orb, landing on a mirror-like ... | The video presents an abstract composition centered around a hexagonal shape adorned ... | ### 65×512×512 Text-to-Video Generation | 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) | | --- | --- | --- | --- | | | | | | | Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. | 3D animation of a small, round, fluffy creature with big, expressive eyes explores a ... | A corgi vlogging itself in tropical Maui. | In a studio, there is a painting depicting a ship sailing through the rough sea. | | | | | | | A robot dog trots down a deserted alley at night, its metallic paws clinking softly ... | A solitary spider weaves its web in a quiet corner. The web shimmers and glows with ... | A lone surfer rides a massive wave, skillfully maneuvering through the surf. The water ... | A solitary cheetah sprints across the savannah, its powerful muscles propelling it ... | | | | | | | A solitary astronaut plants a flag on an alien planet covered in crystal formations ... | At dawn's first light, a spaceship slowly exits the edge of the galaxy against a ...| A dapper puppy in a miniature suit, basking in the afternoon sun, adjusting his tie ... | A wise old elephant painting abstract art with its trunk, each stroke a burst of color ... 
| | | | | | | In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two ... | A Shiba Inu dog wearing a beret and black turtleneck. | A painting of a boat on water comes to life, with waves crashing and the boat becoming ... | Many spotted jellyfish pulsating under water. Their bodies are transparent and glowing ... | | | | | | | An animated hedgehog with distinctive spiky hair and large eyes is seen exploring a ... | An animated rabbit in a playful pink snowboarding outfit is carving its way down a ... | A person clad in a space suit with a helmet and equipped with a chest light and arm ... | | ### 65×512×512 Video Editing | generated 65×512×512 (2.7s) | edited 65×512×512 (2.7s) | | --- | --- | | | | | | | | | | ### 512×512 Text-to-Image Generation ## Detailed Technical Report ### CasualVideoVAE #### Model Structure As the number of frames increases, the encoder overhead of CausalVideoVAE gradually rises. When training with 257 frames, 80GB of VRAM is insufficient for the VAE to encode the video. Therefore, we reduced the number of CausalConv3D layers, retaining only the last two stages of CausalConv3D in the encoder. This change significantly lowers the overhead while maintaining nearly the same performance. Note that we only modified the encoder; the decoder still retains all CausalConv3D layers, as training the Diffusion Model does not require the decoder. vaemodel We compare the computational overhead of the two versions by testing the forward inference of the encoder on the H100. | Version | 129×256×256 | | 257×256×256 | | 513×256×256 | | |---|---|---|---|---|---|---| | | Peak Mem. | Speed | Peak Mem. | Speed |Peak Mem. | Speed | | v1.0.0 | 22G | 2.9 it/s | OOM | - | OOM | - | | v1.1.0 | 18G | 4.9 it/s | 34G | 2.5 it/s | 61G | 1.2 it/s | #### Temporal Module vaemodel In v1.0.0, our temporal module had only TemporalAvgPool. TemporalAvgPool leads to the loss of high-frequency information in the video, such as details and edges. 
To address this issue, we improved this module in v1.1.0. As shown in the figure below, we introduced convolution and added learnable weights, allowing different branches to decouple different features. When we omit CausalConv3D, the video is reconstructed very blurry. Similarly, when we omit TemporalAvgPool, the video becomes very sharp. | | SSIM↑ | LPIPS↓ | PSNR↑ | |---|---|---|---| | Base | 0.850 | 0.091 | 28.047 | | + Frames | 0.868 | 0.070 | 28.829 | | + Reset mixed factor | 0.873 | 0.070 | 29.140 | #### Training Details Similar to v1.0.0, we initialized from the Latent Diffusion's VAE and used tail initialization. For CausalVideoVAE, we trained for 100k steps in the first stage with a video shape of 9×256×256. Subsequently, we increased the frame count from 9 to 25 and found that this significantly improved the model's performance. It is important to clarify that we enabled the mixed factor during both the first and second stages, with a value of a (sigmoid(mixed factor)) reaching 0.88 at the end of training, indicating the model's tendency to retain low-frequency information. In the third stage, we reinitialized the mixed factor to 0.5 (sigmoid(0.5) = 0.6225), which further enhanced the model's capabilities. #### Loss Function We found that using GAN loss helps retain high-frequency information and alleviates grid artifacts. Additionally, we observed that switching from 2D GAN to 3D GAN provides further improvements. | GAN Loss/Step | SSIM↑ | LPIPS↓ | PSNR↑ | |---|---|---|---| | 2D/80k | 0.879 | 0.068 | 29.480 | | 3D/80k | 0.882 | 0.067 | 29.890 | #### Inference Tricks Therefore, we introduced a method called **temporal rollback tiled convolution**, a tiling approach specifically designed for CausalVideoVAE. Specifically, all windows except the first one discard the first frame because the first frame in a window is treated as an image, while the remaining frames should be treated as video frames. 
tiled_temp We tested the speed on the H100 with a window size of 65×256×256. | Version | 129×256×256 | | 257×256×256 | | 513×256×256 | | |---|---|---|---|---|---|---| | | Peak Mem. | Speed | Peak Mem. | Speed |Peak Mem. | Speed | | 4×8×8 | 10G | 1.3 s/it | 10G | 2.6 s/it | 10G | 5.3 s/it | ### Data Construction Since Open-Sora-Plan supports joint training of images and videos, our data collection is divided into two parts: images and videos. Images do not need to originate from videos; they are independent datasets. We spent approximately 32×240 H100 hours generating image and video captions, and all of this is **open source**! #### Image-Text Collection Pipeline We obtained 11 million image-text pairs from [Pixart-Alpha](https://huggingface.co/datasets/PixArt-alpha/SAM-LLaVA-Captions10M), with captions generated by [LLaVA](https://github.com/haotian-liu/LLaVA). Additionally, we utilized the high-quality OCR dataset [Anytext-3M](https://github.com/tyxsspa/AnyText), which pairs each image with corresponding OCR characters. However, these captions were insufficient to describe the entire image, so we used [InternVL-1.5](https://github.com/OpenGVLab/InternVL) for supplementary descriptions. Since T5 only supports English, we filtered for English data, which constitutes about half of the complete dataset. Furthermore, we selected high-quality images from [Laion-5B](https://laion.ai/blog/laion-5b/) to enhance human-like generation quality. The selection criteria included high resolution, high aesthetic scores, and watermark-free images containing people. Here, we are open-sourcing the prompt used for InternVL-1.5: ``` # for anytext-3m Combine this rough caption: "{}", analyze the image in a comprehensive and detailed manner. "{}" can be recognized in the image. # for human-160k Analyze the image in a comprehensive and detailed manner. 
``` | Name | Image Source | Text Captioner | Num pair | |---|---|---|---| | SAM-11M | [SAM](https://ai.meta.com/datasets/segment-anything/) | [LLaVA](https://github.com/haotian-liu/LLaVA) | 11,185,255 | | Anytext-3M-en | [Anytext](https://github.com/tyxsspa/AnyText) | [InternVL-1.5](https://github.com/OpenGVLab/InternVL) | 1,886,137 | | Human-160k | [Laion](https://laion.ai/blog/laion-5b/) | [InternVL-1.5](https://github.com/OpenGVLab/InternVL) | 162,094 | #### Video-Text Collection Pipeline In v1.0.0, we sampled one frame from each video to generate captions. However, as video length increased, a single frame could not adequately describe the entire video's content or temporal movements. Therefore, we used a video captioner to generate captions for the entire video clip. Specifically, we used [ShareGPT4Video](https://sharegpt4video.github.io/), which effectively covers temporal information and describes the entire video content. The v1.1.0 video dataset comprises approximately 3k hours, compared to only 300 hours in v1.0.0. As before, we have open-sourced all text annotations and videos (both under the CC0 license), which can be found [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main). | Name | Hours | Num frames | Num pair | |---|---|---|---| | [Mixkit](https://mixkit.co/) | 42.0h | 65 | 54,735 | | | | 513 | 1,997 | | [Pixabay](https://pixabay.com/) | 353.3h | 65 | 601,513 | | | | 513 | 51,483 | | [Pexel](https://www.pexels.com/) | 2561.9h | 65 | 3,832,666 | | | | 513 | 271,782 | ### Training Diffusion Model Similar to our previous work, we employed a multi-stage cascaded training method. Below is our training card: #### Stage 1 Surprisingly, we initially believed that the performance of the diffusion model would improve with longer training. However, by observing the [logs](https://api.wandb.ai/links/linbin/o76j03j4), we found that videos generated at 50k steps were of higher quality than those at 70-100k steps. 
In fact, extensive sampling revealed that checkpoints at 40-60k steps outperformed those at 80-100k steps. Quantitatively, 50k steps correspond to approximately 2 epochs of training. It is currently unclear whether this is due to overfitting from a small dataset or the limited capacity of the 2+1D model. #### Stage 2 In the second stage, we used Huawei Ascend computing power for training. This stage's training and inference were fully supported by Huawei. We conducted sequence parallel training and inference on a large-scale cluster, distributing one sample across eight ranks. Models trained on Huawei Ascend can also be loaded into GPUs and generate videos of the same quality. #### Stage 3 In the third stage, we further increased the frame count to 513 frames, approximately 21 seconds at 24 FPS. However, this stage presents several challenges, such as ensuring temporal consistency in the 2+1D model over long durations and whether the current amount of data is sufficient. We are still training the model for this stage and continuously monitoring its progress. | Name | Stage 1 | Stage 2 | Stage 3 | |---|---|---|---| | Training Video Size | 65×512×512 | 221×512×512 | 513×512×512 | | Compute (#Num x #Hours) | 80 H100 × 72 | 512 Ascend × 72 | Under Training | | Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512) | Under Training | | Log | [wandb](https://api.wandb.ai/links/linbin/o76j03j4) | - | - | | Training Data | ~3k hours videos + 13M images | | | ### Video Editing The recently proposed [ReVideo](https://mc-e.github.io/project/ReVideo/) achieves accurate video editing by modifying the first frame and applying motion control within the edited area. Although it achieves excellent video editing performance, the editing length is limited by the base model [SVD](https://github.com/Stability-AI/generative-models). 
Open-Sora-Plan, as a fundamental model for long-video generation, can compensate for this issue. Currently, we are collaborating with the ReVideo team to use Open-Sora-Plan as the base model for long video editing. Some preliminary results are shown [here](). The initial version still needs improvement in several aspects. In the future, we will continue to explore integration with ReVideo to develop improved long-video editing models. ## Failed Case and Discussion Despite the promising results of v1.1.0, there remains a gap between our model and Sora. Here, we present some failure cases and discuss them. ### CausalVideoVAE Despite the significant performance improvement of VAE in v1.1.0 over v1.0.0, we still encounter failures in challenging cases, such as sand dunes and leaves. The video on the left shows the reconstructed video downsampled by a factor of 4 in time, while the video on the right is downsampled by a factor of 2. Both exhibit jitter when reconstructing fine-grained features. This indicates that reducing temporal downsampling alone cannot fully resolve the jitter issue. https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1a87d6d8-4bf1-4b4e-83bb-84870c5c3a11 https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/1a87d6d8-4bf1-4b4e-83bb-84870c5c3a11 ### Diffusion Model #### Semantic distortion On the left is a video generated by v1.1.0 showing a puppy in the snow. In this video, the puppy's head exhibits semantic distortion, indicating that the model struggles to correctly identify which head belongs to which dog. On the right is a video generated by Sora's [base model](https://openai.com/index/video-generation-models-as-world-simulators/). We observe that Sora's early base model also experienced semantic distortion issues. This suggests that we may achieve better results by scaling up the model and increasing the amount of training data. 
Prompt:A litter of golden retriever puppies playing in the snow.Their heads pop out of the snow, covered in. | Our | Sora Base×1 | Sora Base×4 | Sora Base×32 | |---|---|---|---| | | | | | #### Limited dynamics The primary difference between videos and images lies in their dynamic nature, where objects undergo a series of changes across consecutive frames. However, the videos generated by v1.1.0 still contain many instances of limited dynamics. Upon reviewing a large number of training videos, we found that while web-crawled videos have high visual quality, they are often filled with meaningless close-up shots. These close-ups typically show minimal movement or are even static. On the left, we present a generated video of a bird, while on the right is a training video we found, which is almost static. There are many similar videos in the dataset from stock footage sites. Prompt:This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side,giving the impression of it looking regal and majestic. The background is blurred,drawing attention to the bird's striking appearance. | Our | Raw video | |---|---| | | | #### Negative prompt We found that using negative prompts can significantly improve video quality, even though we did not explicitly tag the training data with different labels. On the left is a video sampled using a negative prompt, while on the right is a video generated without a negative prompt. This suggests that we may need to incorporate more prior knowledge into the training data. For example, when a video has a watermark, we should note "watermark" in the corresponding caption. When a video's bitrate is too low, we should add more tags to distinguish it from high-quality videos, such as "low quality" or "blurry." 
We believe that explicitly injecting these priors can help the model differentiate between the vast amounts of pretraining data (low quality) and the smaller amounts of fine-tuning data (high quality), thereby generating higher quality videos. Prompt: A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. Negative Prompt: distorted, discontinuous, ugly, blurry, low resolution, motionless, static, low quality | With Negative Prompt | Without Negative Prompt | |---|---| | | | ## Future Work In our future work, we will focus on two main areas: (1) data scaling and (2) model design. Once we have a robust baseline model, we will extend it to handle variable durations and conditional control models. ### Data Scaling #### Data source As mentioned earlier, our dataset is entirely sourced from stock footage websites. Although these videos are of high quality, many consist of close-up shots of specific areas, resulting in slow motion in the videos. We believe this is one of the main reasons for the limited dynamics observed. Therefore, we will continue to collect datasets from diverse sources to address this issue. #### Data volume In v1.1.0, our dataset comprises only ~3k hours of video. We are actively collecting more data and anticipate that the video dataset for the next version will reach ~100k hours. We welcome recommendations from the open-source community for additional datasets. ### Model Design #### CausalVideoVAE In our internal testing, even without downsampling in time, we found that it is not possible to completely resolve the jitter issue in reconstructing fine-grained features. Therefore, we need to reconsider how to mitigate video jitter to the greatest extent possible while simultaneously supporting both images and videos. We will introduce a more powerful CausalVideoVAE in the next version. #### Diffusion Model In v1.1.0, we found that 2+1D models can generate higher-quality videos in short durations.
However, for long videos, they tend to exhibit discontinuities and inconsistencies. Therefore, we will explore more possibilities in model architecture to address this issue. ================================================ FILE: docs/Report-v1.2.0.md ================================================ # Report v1.2.0 In May 2024, we launched Open-Sora-Plan v1.1.0, featuring a 2+1D model architecture that could be quickly utilized for exploratory training in text-to-video generation tasks. However, when handling dense visual tokens, the 2+1D architecture could not simultaneously process spatial and temporal dimensions. Therefore, we transitioned to **a 3D full attention architecture**, which better captures the joint spatial-temporal features. Although this version is experimental, it advances video generation architecture to a new realm, leading us to release it as v1.2.0. Compared to previous video generation models, Open-Sora-Plan v1.2.0 offers the following improvements: 1. **Better compressed visual representations**. We optimized the structure of CausalVideoVAE, which now delivers enhanced performance and higher inference efficiency. 2. **Better video generation architecture**. Instead of 2+1D, we use a diffusion model with a 3D full attention architecture, which provides a better understanding of the world. ### Open-Source Release We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model are made publicly available. - Code: All training scripts and sample scripts. - Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0). - Data: Filtered data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0). ## Gallery 93×1280×720 Text-to-Video Generation. The video quality has been compressed for playback on GitHub.
## Detailed Technical Report ### CausalVideoVAE #### Model Structure The VAE in version 1.2.0 maintains the overall architecture of the previous version but merges the temporal and spatial downsampling layers. In version 1.1.0, we performed spatial downsampling (stride=1,2,2) followed by temporal downsampling (stride=2,1,1). In version 1.2.0, we conduct both spatial and temporal downsampling simultaneously (stride=2,2,2) and perform spatial-temporal upsampling in the decoder (interpolate_factor=2,2,2). Due to the absence of additional convolutions during downsampling and upsampling, this method more seamlessly inherits the weights from the SD2.1 VAE, leading to improved initialization of our VAE. #### Training Details As with v1.1.0, we initialize from the [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse) using tail initialization. We perform the first phase of training on the Kinetic400 video dataset, then use the EMA weights from this phase to initialize the second phase, which is fine-tuned on high-quality data (collected in v1.1.0). All training is conducted on 25-frame 256×256 videos using **one A100 node**. | Training stage | Dataset | Training steps | |---|---|---| | 1 | K400 | 200,000 | | 2 | collected in v1.1.0 | 450,000 | #### Evaluation We evaluated our VAE on the validation sets of two video datasets: [Webvid](https://github.com/m-bain/webvid) and [Panda70m](https://github.com/snap-research/Panda-70M/), and compared it with our [v1.1.0](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.1.0.md), [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse), [CV-VAE](https://github.com/AILab-CVC/CV-VAE), and [Open-Sora's VAE](https://github.com/hpcaitech/Open-Sora). The Webvid validation set contains 5k videos, while the Panda70m validation set has 6k videos. The videos were resized to 256 pixels on the short side, center-cropped to 256x256, and then 33 consecutive frames were extracted. 
We used PSNR, SSIM, and LPIPS metrics, and measured the encoding speed on an A100 GPU. The specific results are as follows: **WebVid** | Model | Compress Ratio |PSNR↑ | SSIM↑ |LPIPS↓ | |---|---|---|---|---| | SD2-1 VAE | 1x8x8 | 30.19 | 0.8379 | 0.0568 | | SVD VAE | 1x8x8 |31.15 |0.8686 | **0.0547** | | CV-VAE | 4x8x8 | 30.76 | 0.8566 | 0.0803 | | Open-Sora VAE | 4x8x8 | 31.12 | 0.8569 | 0.1003 | | Open-Sora Plan v1.1 | 4x8x8 | 30.26 | 0.8597 |0.0551 | | Open-Sora Plan v1.2 | 4x8x8| **31.16** | **0.8694** | 0.0586 | **Panda70M** | Model | Compress Ratio| PSNR↑ | SSIM↑ |LPIPS↓ | |---|---|---|---|---| | SD2-1 VAE | 1x8x8 |30.40 | 0.8894 | 0.0396 | | SVD VAE | 1x8x8 |31.00 | **0.9058** | **0.0379** | | CV-VAE | 4x8x8| 29.57 | 0.8795 | 0.0673 | | Open-Sora VAE | 4x8x8 | **31.06** | 0.8969 | 0.0666 | | Open-Sora Plan v1.1 | 4x8x8 | 29.16 | 0.8844 | 0.0481 | | Open-Sora Plan v1.2 | 4x8x8| 30.49 |0.8970 |0.0454| **Encode Time on A100** |Input Size| CV-VAE | Open-Sora | Open-Sora Plan v1.1 | Open-Sora Plan v1.2 | |---|---|---|---|---| | 33x256x256 | 0.186 | 0.147 |0.104 | **0.102** | | 81x256x256 | 0.465 | 0.357 |0.243 | **0.242** | ### Training Text-to-Video Diffusion Model #### Model Structure The most significant change is that we **replaced all 2+1D Transformer blocks with 3D full attention blocks**. Each video is first processed by a patch embedding layer, which downsamples the spatial dimensions by a factor of 2. The video is then flattened into a one-dimensional sequence across the frame, width, and height dimensions. We replaced [T5-XXL](https://huggingface.co/DeepFloyd/t5-v1_1-xxl) with [mT5-XXL](https://huggingface.co/google/mt5-xxl) to enhance multilingual adaptation. Additionally, we incorporated RoPE. ### Sequence Parallelism Due to the high computational complexity of 3D full attention, we must allocate a video across 2 GPUs for parallel processing when training with long-duration and high-resolution videos.
We can control the number of GPUs used for a video sample by adjusting the batch size on a node. For example, with `sp_size=8` and `train_sp_batch_size=4`, 2 GPUs are used for a single sample. **We support sequence parallelism for both training and inference**. **Training on 93×720p**, we report speed on H100. | GPU (sp_size) | batch size | Enable sp | Train_sp_batch_size | Speed | Step per day | |---|---|---|---|---|---| |8|8|×|-|100s/step|~850| |8|-|√|4|53s/step|~1600| |8|-|√|2|27s/step|~3200| **Inference on 93×720p**, we report speed on H100. | Size | 1 GPU | 8 GPUs | |---|---|---| |29×720p|420s/100step|80s/100step| |93×720p|3400s/100step|450s/100step| #### Dynamic training Deep neural networks are typically trained using batched inputs. For efficient hardware processing, batch shapes are fixed, leading to a fixed data size. This requires either cropping or padding images to a uniform size, both of which have drawbacks: cropping degrades performance, while padding is inefficient and results in significant information loss. Generally, there are three methods for training with arbitrary token counts: Patch n' Pack, bucket, and pad-mask. **Patch n' Pack** ([NaViT](https://arxiv.org/abs/2307.06304)): bypasses the fixed sequence length limitation by combining tokens from multiple samples into a new sample. This approach allows variable-resolution images while maintaining aspect ratios by packaging multiple samples together, thereby reducing training time and enhancing performance and flexibility. However, this method involves significant code modifications and requires re-adaptation when exploring different model architectures in fields with unstable model designs. **Bucket** ([Pixart-alpha](https://arxiv.org/abs/2310.00426), [Open-Sora](https://github.com/hpcaitech/Open-Sora)): This method packages data of different resolutions into buckets, sampling batches from each bucket to ensure same resolution within each batch. 
It requires minimal code modifications to the model, mainly adjusting the data sampling strategy. **Pad-mask** ([FiT](https://arxiv.org/abs/2402.12376), our v1.0/v1.1): This method sets a maximum resolution and pads all data to this resolution, generating a corresponding mask. Although the approach is straightforward, it is computationally inefficient. We believe that current video generation models are still in an exploratory phase. Extensive modifications to model code during this period can incur unnecessary development costs. The pad-mask method, while straightforward, is computationally inefficient and can waste resources in video, which involves dense computations. Ultimately, we chose the bucket strategy, which requires no modifications to the model code. Next, we will explain how our bucket strategy supports arbitrary lengths and resolutions. For simplicity, we will use video duration as an example: We define a megabatch as the total data processed in a single step across all GPUs. A megabatch can be divided into multiple batches, with each batch corresponding to the data processed by a single GPU. **Sort by frame**: The first step is to count the number of frames in all video data and sort them. This step aims to group similar data together, with sorting being one method to achieve this. **Group megabatch**: Next, all data is divided into groups, each forming a megabatch. Since all data is pre-sorted, most videos within a megabatch have the same number of frames. However, there will always be boundary cases, such as having both 61-frame and 1-frame videos in a single megabatch. **Re-organize megabatch**: We re-organize these special megabatches, which actually constitute a small proportion. We randomly replace the minority data in the megabatch with the majority data, thus re-organizing it into a megabatch with same frame counts. **Shuffle megabatch**: To ensure data randomness, we shuffle both within each megabatch and between different megabatches. 
When supporting dynamic resolutions, we simply replace each sample's frame sequence with (frame × height × width). This method ensures that the data dimension processed by each GPU in every step is the same, preventing situations where GPU1 waits for GPU0 to finish processing a longer video. Moreover, it is entirely decoupled from the model code, serving as a plug-and-play video sampling strategy. #### Training stage Similar to previous work, we use a multi-stage training approach. With the 3D DiT architecture, all parameters can be transferred from images to videos without loss. To explore training costs, all parameters of the diffusion model are trained from scratch. Therefore, we first train a text-to-image model, using the training strategy from [Pixart-alpha](https://arxiv.org/abs/2310.00426). The video model is initialized with weights from a 480p image model. We first train 480p videos with 29 frames. Next, we adapt the weights to 720p resolution, training on approximately 6 million higher-quality (HQ) samples from Panda70M, filtered for aesthetic quality and motion. Finally, we refine the model with an even higher-quality (HQ) subset of 1 million samples. After that, we use filtered data (collected in v1.1.0) for fine-tuning 93-frame 720p videos. Below is our training card. We release the annotation file [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/anno_json).
| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |Stage 5 | |---|---|---|---|---|---| | Training Video Size | 1×320×240 | 1×640×480 | 29×640×480 | 29×1280×720 | 93×1280×720 | | Training Step| 146k | 200k | 30k | 21k | 3k | | Compute (#Num x #Hours) | 32 Ascend × 81 | 32 Ascend × 142 | 128 Ascend × 38 | 256 H100 × 64 | 256 H100 × 84 | | Checkpoint | - | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p) | | Log | - | - | [wandb](https://api.wandb.ai/links/1471742727-Huawei/trdu2kba) | [wandb](https://api.wandb.ai/links/linbin/vvxvcd7s) | [wandb](https://api.wandb.ai/links/linbin/easg3qkl) | Training Data | [10M SAM](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/sam_image_11185255_resolution.json) | 5M internal image data | [6M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ6M.json) | [6M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ6M.json) | [1M HQ Panda70M](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/blob/main/anno_json/Panda70M_HQ1M.json) and [100k HQ data](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/anno_json) (collected in v1.1.0) | Additionally, we fine-tuned 3.5k steps from the final 93×720p to get [93×480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p) for community research use. ### Training Image-to-Video Diffusion Model #### Model Structure To reuse the weights of the Text-to-Video model, our Image-to-Video model is inspired by the Stable Diffusion Inpainting Model and adopts a strategy based on frame-level inpainting. 
By incorporating three types of information—original noise, masked video, and mask—under different control frame conditions, our model can generate coherent videos while ensuring flexibility in its usage. Compared to the denoiser structure of the Text-to-Video model, the Inpainting Model's denoiser has only changed the number of channels in the `conv in` layer. To ensure the model has good prior knowledge, we introduce the masked video and mask information through zero initialization. We believe this is due to the 2+1D structure's lack of ability to establish long-range information dependencies, and relying solely on attention in the temporal dimension makes it difficult to capture information changes under frame control. In Text-to-Video tasks, this phenomenon is not as evident because all frames share the same text prompt embedding. However, in Image-to-Video tasks, simply concatenating images in the channel dimension does not ensure the model can accurately capture changes between frames. This is because the model cannot directly replicate image information from the channels to reduce the loss, and the 2+1D structure's interaction solely on the temporal axis fails to allow the model to discern which information from the control frames can be utilized, especially when there are significant differences between frames. Therefore, without shared image-semantic information, the control frame information might not be effectively conveyed to each frame. ##### About Semantic Adapter In previous models based on the Unet 2+1D architecture, it is necessary to input the control frames into the CLIP model to obtain semantic embeddings. These semantic embeddings are then injected into the denoiser through cross-attention. The structure that extracts CLIP embeddings and injects them into the denoiser is commonly referred to as a semantic adapter. In the 2+1D architecture, the semantic adapter is commonly present.
Additionally, papers like [DynamiCrafter](https://arxiv.org/abs/2310.12190) have pointed out that incorporating the semantic adapter helps maintain stability in the generated videos. We believe this is because the 2+1D structure lacks the ability to establish long-range information dependencies, and relying solely on attention in the temporal dimension makes it difficult to capture information changes under frame control. In the Text-to-Video task, this phenomenon is not as evident because all frames share the same text prompt embedding. However, in the Image-to-Video task, without shared semantic information, it may lead to the inability to effectively transfer control frame information to each individual frame.
We conducted a simple comparison of the performance of using the Inpainting Model under the 2+1D structure (Open-Sora Plan v1.1, left in the figure) versus the 3D structure (Open-Sora Plan v1.2, right in the figure). With the same number of optimization steps, the probability of unstable visual performance in the 2+1D structure was significantly higher than in the 3D structure. Even at convergence, the 2+1D structure's visual stability was still inferior to that of the 3D structure, and it was even worse than the early training stages of the 3D structure. ## Future Work and Discussion #### CausalVideoVAE We observed that high-frequency motion information in videos tends to exhibit jitter, and increasing training duration and data volume does not significantly alleviate this issue. In videos, compressing the duration while maintaining the original latent dimension can lead to significant information loss. A more robust VAE will be released in the next version. #### Diffusion Model We replaced T5 with mT5 to enhance multilingual capabilities, but this capability is limited as our training data is currently only in English. The multilingual ability primarily comes from the mT5 mapping space. We will explore additional text encoders and expand the data in the next steps. Our model performs well in generating character consistency, likely due to panda70m being a character-centric dataset. However, it still shows poor performance in text consistency and object generalization. We suspect this may be due to the limited amount of data the model has seen, as evidenced by the non-convergence of the loss in the final stage. 
**We hope to collaborate with the open-source community to optimize the 3D DiT architecture.** ================================================ FILE: docs/Report-v1.3.0.md ================================================ # Report v1.3.0 In August 2024, we released Open-Sora-Plan v1.2.0, transitioning to a 3D full attention architecture, which enhanced the capture of joint spatial-temporal features. However, the substantial computational cost made it unsustainable, and the lack of a clear training strategy hindered continuous progress along a focused path. In version 1.3.0, Open-Sora-Plan introduced the following five key features: **1. A more powerful and cost-efficient WFVAE.** We decompose video into several sub-bands using wavelet transforms, naturally capturing information across different frequency domains, leading to more efficient and robust VAE learning. **2. Prompt Refiner.** A large language model designed to refine short text inputs. **3. High-quality data cleaning strategy.** The cleaned panda70m dataset retains only 27% of the original data. **4. DiT with new sparse attention.** A more cost-effective and efficient learning approach. **5. Dynamic resolution and dynamic duration.** This enables more efficient utilization of videos with varying lengths (treating a single frame as an image). ### Open-Source Release We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, model will be made publicly available. - Code: All training scripts and sample scripts. - Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0). - Data: The data of prompt refiner is [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). ## Gallery Text & Image to Video Generation.
[![Demo Video of Open-Sora Plan V1.3](https://github.com/user-attachments/assets/4ff1d873-3dde-4905-a907-dbff51174c20)](https://www.bilibili.com/video/BV1KR2fYPEF5/?spm_id_from=333.999.0.0&vd_source=cfda99203e659100629b465161f1d87d) ## Detailed Technical Report ### WF-VAE As video generation models move toward higher resolutions and longer durations, the computational cost of video VAEs grows exponentially, becoming unsustainable. Most related work addresses this by using tiling to reduce inference memory consumption. However, in high-resolution, long-duration scenarios, tiling significantly increases inference time. Additionally, since tiling is lossy for latents, it can lead to visual artifacts such as shadows or flickering in the generated videos. We therefore introduce WF-VAE, a new model designed to handle these problems. #### Model Structure
SCR-20241023-tzct
The compression rate fundamentally determines the quality of VAE-reconstructed videos. We analyzed the energy and entropy of different subbands obtained through wavelet transform and found that most of the energy in videos is concentrated in the low-frequency bands. Moreover, by replacing the `LLL` subband of the VAE-reconstructed video with the original video's `LLL` subband, we observed a significant improvement in the spatiotemporal quality of the videos.
In previous VAE architectures, the lack of a "highway" for transmitting the dominant energy during video compression meant that this pathway had to be gradually established during model training, leading to redundancy in model parameters and structure. Therefore, in our model design, we created a more efficient transmission path for the LLL subband energy, significantly simplifying the model architecture, reducing inference time, and lowering memory consumption. #### Training Details More details will be provided in the forthcoming paper. #### Ablation Study In our experiments, we used the K400 training and validation sets, conducted on 8xH100 GPUs. The latent dimension was fixed at 4. We observed that as model parameters increased, there was still room for improvement in reconstruction metrics. GroupNorm showed instability during training, performing worse than LayerNorm on PSNR but better on LPIPS.
#### Performance The following metrics were tested on H100 with float32 precision. For fairness, tiling was disabled for all models, and direct inference was performed.
SCR-20241023-tzwz
#### Evaluation We evaluated PSNR and LPIPS on the Panda70M test set at 256 pixels and 33 frames. In the open-source WF-VAE-S (8-dim), our encoder was distilled from the 8-dim OD-VAE, resulting in some metric degradation compared to direct training. | Latent Dim | Model | Params | PSNR | LPIPS | |---|---|---|---|---| | 4 | OD-VAE(Our VAE in v1.2.0) | 94M + 144M | 30.311| 0.043| | 4 | WFVAE-S | 38M + 108M | 30.579 | 0.044 | | 8 | WFVAE-S(Distillation) |38M + 108M | 31.764|0.050 | #### Causal Cache
To address the issue of tiling, we replaced GroupNorm with LayerNorm and introduced a novel method called **Causal Cache**, enabling lossless temporal block-wise inference. First, we replaced GroupNorm with LayerNorm and utilized the properties of CausalConv3D to achieve lossless inference through temporal dimension chunking. In each layer of CausalConv3D, we cache the information from the previous few frames to maintain continuity during the convolution sliding operation for the next temporal chunk, thereby enabling lossless processing. As illustrated, we use a kernel size of 3 and a stride of 1 as an example: **Initial Chunk (chunk idx=0):** For the first time chunk, we perform standard causal padding to support joint processing of images and videos. After the convolution operation, we cache the last two frames of this chunk into the causal cache in preparation for the next chunk's inference. **Subsequent Chunks (chunk idx=1 and beyond):** Starting from the second time chunk, we no longer use causal padding. Instead, we concatenate the cached causal information from the previous chunk to the front of the current chunk. We continue to cache the last two frames of the current input into the causal cache for use in subsequent chunks. ## Prompt Refiner User-provided captions are typically fewer than 10 words, whereas the text annotations in the current training data are often dense. This inconsistency between training and inference may result in poor visual quality and weak text alignment. We categorize captions into four types: (1) Short captions from real user input; we collected 11k from [COCO](https://cocodataset.org/#home). (2) Captions composed of multiple tags; we collected 5k from [DiffusionDB](https://github.com/poloclub/diffusiondb). (3) Medium-length captions generated by large language models; 3k sourced from [JourneyDB](https://github.com/JourneyDB/JourneyDB).
(4) Ultra-long, surrealist captions, sourced from Sora/Vidu/Pika/Veo and approximately 0.5k generated by GPT. We used ChatGPT to rewrite the above captions, with the following instructions provided to ChatGPT: ``` rewrite the sentence to contain subject description action, scene description. Optional: camera language, light and shadow, atmosphere and conceive some additional actions to make the sentence more dynamic, make sure it is a fluent sentence, not nonsense. ``` Finally, we performed LoRA fine-tuning using [LLaMa 3.1](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), completing the training in just 30 minutes with a single H100. We fine-tuned for only 1 epoch, using a batch size of 32 and a LoRA rank of 64. The log can be found [here](https://api.wandb.ai/links/1471742727-Huawei/p5xmkft5). We open-sourced the data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). ### Data Construction We randomly sampled from the original Panda70m dataset and found many videos to be static, contain multiple subtitles, or suffer from motion blur. Additionally, the captions in Panda70m did not always accurately describe the video content. To address this, we designed a video filtering pipeline, which retained approximately 27% of the videos after processing.
#### Jump Cut and Detect Motion We used [LPIPS](https://github.com/richzhang/PerceptualSimilarity) frame-skipping to compute inter-frame semantic similarity, identifying anomalies as cut points and taking the mean as the motion score. We found that videos with motion scores below 0.001 were nearly static, while those above 0.3 exhibited significant jitter and flicker. After applying this method, we manually reviewed 2k videos and concluded that the cut detection accuracy was sufficient for pre-training requirements. #### OCR We estimated the average position of subtitles on common video platforms to be around 18%. Consequently, we set the maximum crop threshold to 20% of the video's original dimensions and used [EasyOCR](https://github.com/JaidedAI/EasyOCR) to detect subtitles (sampling one frame per second). However, not all videos have subtitles or printed text located at the edges; this method may miss text appearing in central areas, such as in advertisement videos or speeches. Nonetheless, we cannot assume that the presence of text in a video necessitates filtering it out, as certain texts in specific contexts can be meaningful. We leave such judgments to aesthetic considerations. #### Aesthetic As before, we used the [Laion aesthetic predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor) for evaluation. Based on the visualization [website](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html), we determined that a score of 4.75 serves as a suitable threshold, effectively filtering out excessive text while retaining high-quality aesthetics. We will add an additional aesthetic prompt, such as `A high-aesthetic scene, ` for data with a score above 6.25. #### Video Quality Some old photos or videos have very low bit rates, resulting in blurry visual effects even at 480P resolution, often resembling a mosaic appearance. 
Aesthetic filtering struggles to exclude these videos, as it resizes images to 224 resolution. We aim to establish a metric for assessing absolute video quality, independent of the visual content itself, focusing solely on compression artifacts, low bit rates, and jitter. We employed the technical prediction score from [DOVER](https://github.com/VQAssessment/DOVER) and excluded videos with scores below 0. #### Recheck Motion Since some videos contain subtitles, variations in the subtitles may lead to inaccurate motion values. Therefore, we re-evaluated the motion values and discarded static videos. #### Captioning We used [QWen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) for video annotation. ``` Please describe the content of this video in as much detail as possible, including the objects, scenery, animals, characters, and camera movements within the video. Do not include '\n' in your response. Please start the description with the video content directly. Please describe the content of the video and the changes that occur, in chronological order. ``` However, the 7B model tends to generate certain prefixes, such as "This video" or "The video." We compiled a list of all irrelevant opening strings and removed them. 
``` 'The video depicts ', 'The video captures ', 'In the video, ', 'The video showcases ', 'The video features ', 'The video is ', 'The video appears to be ', 'The video shows ', 'The video begins with ', 'The video displays ', 'The video begins in ', 'The video consists of ', 'The video opens with ', 'The video opens on ', 'The video appears to capture ', 'The video appears to show ', "The video appears to depict ", "The video opens in ", "The video appears to focus closely on ", "The video starts with ", "The video begins inside ", "The video presents ", "The video takes place in ", "The video appears to showcase ", "The video appears to display ", "The video appears to focus on ", "The video appears to feature " ``` ### Training Text-to-Video Diffusion Model #### Framework ##### Skiparse (Skip-Sparse) Attention In video generation models, alternating 2+1D spatial-temporal blocks is a commonly used approach, yet these models lack long-range modeling, limiting their performance ceiling. Consequently, models like [CogVideoX](https://arxiv.org/abs/2408.06072), [Meta Movie Gen](https://ai.meta.com/research/movie-gen/), and Open-Sora Plan v1.2 employ **Full 3D Attention** as a denoiser, achieving substantially improved visual fidelity and motion quality compared to 2+1D models. This approach, however, requires calculating attention across all tokens in each clip encoding, which significantly raises training costs. For instance, Open-Sora Plan v1.2, training a 2.7-billion-parameter model, takes **100 seconds per step at 93x720p and over 15 seconds per step at 93x480p**, severely constraining scalability under limited computational resources. To accelerate training while ensuring adequate performance, we propose the **Skiparse (Skip-Sparse) Attention** method. Specifically, under a fixed sparse ratio $$k$$ , we organize candidate tokens for attention through two alternating skip-gather methods. 
This approach keeps the attention operation global while effectively reducing FLOPs, enabling faster training of 3D Attention models. In our experiments, applying Skiparse with sparse ratio $$k=4$$ to a 2.7B model reduced training time to **42 seconds per step at 93x720p and 8 seconds per step at 93x480p**.
**Skiparse DiT modifies only the Attention component** within the Transformer Block, using two alternating Skip Sparse Transformer Blocks. With sparse ratio $$k$$, the sequence length in the attention operation reduces to $$\frac{1}{k}$$ of the original, and batch size increases by $$k$$-fold, lowering the theoretical complexity of self-attention to $$\frac{1}{k}$$ of the original, while cross-attention complexity remains unchanged. Due to GPU/NPU parallel processing, increasing the batch size by $$k$$-fold does not linearly decrease speed to $$\frac{1}{k}$$, resulting in a performance boost that exceeds theoretical expectations.
In Single Skip mode, the elements located at positions $$[0, k, 2k, 3k, ...]$$ , $$[1, k+1, 2k+1, 3k+1, ...]$$ , ..., $$[k-1, 2k-1, 3k-1, ...]$$ are grouped into the same scope (with each list forming one scope of elements). The figure above, using $$k=2$$ as an example, illustrates this organizational structure. This concept is straightforward, as each token performs attention with tokens at a stride of $$k$$ (i.e., skipping the $$k-1$$ tokens in between).
In Group Skip mode, elements at positions $$[(0, 1, ..., k-1), (k^2, k^2+1, ..., k^2+k-1), (2k^2, 2k^2+1, ..., 2k^2+k-1), ...]$$ , $$[(k, k+1, ..., 2k-1), (k^2+k, k^2+k+1, ..., k^2+2k-1), (2k^2+k, 2k^2+k+1, ..., 2k^2+2k-1), ...]$$ , ..., $$[(k^2-k, k^2-k+1, ..., k^2-1), (2k^2-k, 2k^2-k+1, ..., 2k^2-1), (3k^2-k, 3k^2-k+1, ..., 3k^2-1), ...]$$ are grouped together as a scope (with each list forming a scope). This arrangement may seem complex numerically, so it can be helpful to understand with the above figure. In this pattern, we first **group adjacent tokens** in segments of length $$k$$ , then **bundle these groups** with other groups that are spaced $$k-1$$ groups apart into a single scope. For example, in $$[(0, 1, ..., k-1), (k^2, k^2+1, ..., k^2+k-1), (2k^2, 2k^2+1, ..., 2k^2+k-1), ...]$$ , each set of indices in parentheses represents a group. Each group is then connected with another group that is offset by $$k-1$$ groups, forming one scope. Since the last index of the first group is $$k-1$$ , the first token in the next group to be linked will be at index $$k-1+k(k-1)+1=k^2$$ . Following this pattern, you can determine the indices for each scope in this configuration. ##### Why "Skiparse"? The 2+1D DiT models temporal understanding only along the time axis of a single spatial location, theoretically and practically limiting performance. In real-world scenarios, changes at a specific spatial location are typically influenced not by prior content at that same location but by content across all spatial locations at preceding times. This constraint makes it challenging for 2+1D DiT to model complex physical dynamics accurately. Full 3D Attention represents global attention, allowing any spatial position at any time to access information from any other position across all times, aligning well with real-world physical modeling. 
However, this approach is time-consuming and inefficient, as visual information often contains considerable redundancy, making it unnecessary to establish attention across all spatiotemporal tokens. **An ideal spatiotemporal modeling approach should employ attention that minimizes the overhead from redundant visual information while capturing the complexities of the dynamic physical world**. Reducing redundancy requires avoiding connections among all tokens, yet global spatiotemporal attention remains essential for modeling complex physical interactions. To achieve a balance between 2+1D efficiency and Full 3D’s strong spatiotemporal modeling, we developed Skiparse Attention. This approach provides global spatiotemporal attention within each block, with each block having the same “receptive field”. The use of "group" operations also introduces a degree of locality, aligning well with visual tasks. Interestingly, once you understand the Skiparse Attention mechanism, you’ll notice that **the attention in 2+1D DiT corresponds to a sparse ratio of $$k=HW$$ (since $$T \ll HW$$ , making the "skip" in Group Skip negligible), while Full 3D DiT corresponds to a sparse ratio of $$k=1$$.** In Skiparse Attention, $$k$$ is typically chosen to be close to 1, yet far smaller than $$HW$$ , making it a 3D Attention that approaches the effectiveness of Full 3D Attention. In Skiparse Attention, Single Skip is a straightforward operation, easily understood by most. Within Group Skip, the Group operation is also intuitive, serving as a means to model local information. However, **Group Skip involves not only grouping but also skipping**—particularly between groups—which is often overlooked. This oversight frequently leads researchers to confuse Skiparse Attention with a Skip + Window Attention approach. The key difference lies in even-numbered blocks: Window Attention only groups tokens without skipping between groups. 
The distinctions among these attention methods are illustrated in the figure below, which shows the attention scopes for self-attention only, with dark tokens representing the tokens involved in each attention calculation.
To deeply understand why nearly global attention is necessary and why Skiparse Attention theoretically approximates Full 3D Attention more closely than other common methods, we introduce the concept of **Average Attention Distance**. This concept is defined as follows: for any two tokens, if it takes $$m$$ attention operations to establish a connection between them, the attention distance is $$m$$ . The average attention distance for a tensor is then the mean of the attention distances across all token pairs, representing the corresponding attention method’s overall connectivity efficiency. The average attention distance of all tokens within a tensor is defined as the average attention distance for that particular attention method. For example, in Full 3D Attention, any token can connect with any other token in just one attention operation, resulting in an average attention distance of 1. In 2+1D Attention, the process is somewhat more complex, though still straightforward to understand. In all configurations above, any two different tokens can connect with an attention distance between 1 and 2 (Note that we define the attention distance between a token and itself as zero). Thus, for the other three attention methods, we can first identify which tokens have an attention distance of 1. Subsequently, tokens with an attention distance of 2 can be determined, allowing us to calculate the average attention distance. In the $$2N$$ Block, attention operates over the $$(H, W)$$ dimensions, where tokens within this region have an attention distance of 1. In the $$2N+1$$ Block, attention operates along the $$(T)$$ dimension, also assigning an attention distance of 1 for these tokens. The total number of tokens with an attention distance of 1 in this case is $$HW + T - 2$$ (excluding the token itself, hence $$(HW + T - 1) - 1 = HW + T - 2$$). 
Therefore, in 2+1D Attention, the average attention distance (AVG Attention Distance) is: $$ \begin{aligned} d&=\frac{1}{THW}\left[ 1\times 0+\left( HW+T-2 \right) \times 1+\left[ THW-\left( HW+T-1 \right) \right] \times 2 \right]\\ &=2-\left( \frac{1}{T}+\frac{1}{HW} \right)\\ \end{aligned} $$ In Skip+Window Attention, aside from the token itself, there are $$\frac{THW}{k} - 1$$ tokens with an attention distance of 1 in the $$2N$$ Block, and $$k - 1$$ tokens with an attention distance of 1 in the $$2N+1$$ Block. Thus, the total number of tokens with an attention distance of 1 is $$\frac{THW}{k} + k - 2$$. Therefore, in Skip+Window Attention, the average attention distance (AVG Attention Distance) is: $$ \begin{aligned} d&=\frac{1}{THW}\left[ 1\times 0+\left( \frac{THW}{k}+k-2 \right) \times 1+\left[ THW-\left( \frac{THW}{k}+k-1 \right) \right] \times 2 \right]\\ &=2-\left( \frac{1}{k}+\frac{k}{THW} \right)\\ \end{aligned} $$ In Skiparse Attention, aside from the token itself, $$\frac{THW}{k} - 1$$ tokens have an attention distance of 1 in the $$2N$$ Block, and $$\frac{THW}{k} - 1$$ tokens have an attention distance of 1 in the $$2N+1$$ Block. Notably, $$\frac{THW}{k^2} - 1$$ tokens can establish an attention distance of 1 in both blocks and should not be counted twice. Therefore, in Skiparse Attention, the average attention distance (AVG Attention Distance) is: $$ \begin{aligned} d&=\frac{1}{THW}\left[ 1\times 0+\left[ \frac{2THW}{k}-2-\left( \frac{THW}{k^2}-1 \right) \right] \times 1+\left[ THW-\left( \frac{2THW}{k}-\frac{THW}{k^2} \right) \right] \times 2 \right]\\ &=2-\frac{2}{k}+\frac{1}{k^2}-\frac{1}{THW}\\ &=2-\frac{2}{k}+\frac{1}{k^2}\left( 1\ll THW \right)\\ \end{aligned} $$ In fact, in the Group Skip of the $$2N+1$$ Block, the actual sequence length is $$k\lceil \frac{THW}{k^2} \rceil$$ rather than $$\frac{THW}{k}$$. 
The prior calculation assumes the ideal case where $$k \ll THW$$ and $$k$$ divides $$THW$$ exactly, yielding $$k\lceil \frac{THW}{k^2} \rceil = k \cdot \frac{THW}{k^2} = \frac{THW}{k}$$. In practical applications, excessively large $$k$$ values are typically avoided, making this derivation a reasonably accurate approximation for general use. Specifically, when $$k = HW$$ and padding is disregarded, since $$T \ll HW$$, group skip attention reduces to window attention with a window size of $$HW$$. Given that padding does not affect the final computation, Skiparse Attention is equivalent to 2+1D Attention when $$k = HW$$. For the commonly used resolution of 93x512x512, using a causal VAE with a 4x8x8 compression rate and a DiT with a 1x2x2 patch embedding, we obtain a latent shape of 24x32x32 before applying attention. The AVG Attention Distance for different calculation methods would then be as follows: | | Full 3D Attention | 2+1D Attention | | ---------------------- | ----------------- | --------------- | | AVG Attention Distance | 1 | 1.957 | | | Skip + Window Attention(k=2) | Skip + Window Attention(k=4) | Skip + Window Attention(k=6) | Skip + Window Attention(k=8) | | ---------------------- | ---------------------------- | ---------------------------- | ---------------------------- | ---------------------------- | | AVG Attention Distance | 1.500 | 1.750 | 1.833 | 1.875 | | | Skiparse Attention(k=2) | Skiparse Attention(k=4) | Skiparse Attention(k=6) | Skiparse Attention(k=8) | | ---------------------- | ----------------------- | ----------------------- | ----------------------- | ----------------------- | | AVG Attention Distance | 1.250 | 1.563 | 1.694 | 1.766 | In 2+1D Attention, the average attention distance is 1.957, larger than that of Skip + Window Attention and Skiparse Attention at commonly used sparse ratios. 
While Skip + Window Attention achieves a shorter average attention distance, its modeling capacity remains limited due to the locality of attention in its 2N+1 blocks. Skiparse Attention, with the shortest average attention distance, applies global attention in both 2N and 2N+1 blocks, making its spatiotemporal modeling capabilities closer to Full 3D Attention than the other two non-Full 3D methods.
The figure above shows how Skiparse Attention’s AVG Attention Distance changes with sparse ratio $$k$$. We can summarize the characteristics of these attention types as follows: | | Full 3D Attention | 2+1D Attention | Skip + Window Attention | Skiparse Attention | | ---------------------------------- | ----------------- | -------------------------------- | --------------------------------------- | ------------------------------------------------------------ | | Speed | Slow | Fast | Depending on $$k$$ | Depending on $$k$$ | | Spatiotemporal modeling capability | Strong | Weak | Weak | Approaches Full 3D | | Is attention global? | Yes | No | Half of the attention blocks are global | Yes | | Computation load per block | Equal | Not Equal | Not Equal | Equal | | AVG Attention Distance | 1 | $$2-(\frac{1}{T}+\frac{1}{HW})$$ | $$2-(\frac{1}{k}+\frac{k}{THW})$$ | $$2-\frac{2}{k}+\frac{1}{k^2},1<k\ll THW$$ |
#### Dynamic training Overall, we maintained the bucket strategy from v1.2.0, pre-defining the shape of each video during training and aggregating data of the same shape through a sampler. Finally, the dataloader retrieves data based on our aggregated indices. In our early implementation, we specified `--max_width`, `--max_height`, `--min_width`, and `--min_height`. While this allows for specifying arbitrary resolutions within a certain range, this approach can easily lead to OOM issues during video training. For instance, for a 720P (720×1280) video, if the maximum dimensions are set to 720, the video would be scaled to 405×720. However, if there are square videos with resolutions greater than 720, they would be scaled to 720×720. Most videos are non-square, and to prevent OOM, we need to reserve GPU memory, which leads to significant computational waste. Therefore, we recommend using `--max_token` and `--min_token` to limit any range, as this aligns better with the Transformer architecture. #### Training scheduler We replaced the eps-pred loss with v-pred loss and enable ZeroSNR. For videos, we resample to 16 FPS for training. **Stage 1**: We initially initialized from the image weights of version 1.2.0 and trained images at a resolution of 1x320x320. The objective of this phase was to fine-tune the 3D dense attention model to a sparse attention model. The entire fine-tuning process involved approximately 100k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily sourced from SAM in version 1.2.0. **Stage 2**: We trained the model jointly on images and videos, with a maximum resolution of 93x320x320. The entire fine-tuning process involved approximately 300k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily sourced from SAM in version 1.2.0, while the video data consisted of the unfiltered Panda70m. 
In fact, the model had nearly converged around 100k steps, and by 300k steps, there were no significant gains. Subsequently, we performed data cleaning and caption rewriting, with further data analysis discussed at the end. **Stage 3**: We fine-tuned the model using our filtered Panda70m dataset, with a fixed resolution of 93x352x640. The entire fine-tuning process involved approximately 30k steps, with a batch size of 1024 and a learning rate of 1e-5. ### Training Image-to-Video Diffusion Model #### Framework
In terms of framework, Open-Sora Plan v1.3 continues to use the Inpainting model architecture from Open-Sora Plan v1.2. #### Data processing For data processing, Open-Sora Plan v1.3 introduces two new mask types: an all-1 mask and an all-0 mask. This brings the total number of mask types in the Inpainting Model to six.
In the figure above, black indicates retained frames, while white denotes discarded frames. The corresponding frame strategies are as follows: - **Clear**: Retain all frames. - **T2V**: Discard all frames. - **I2V**: Retain only the first frame; discard the rest. - **Transition**: Retain only the first and last frames; discard the rest. - **Continuation**: Retain the first $$n$$ frames; discard the rest. - **Random**: Retain $$n$$ randomly selected frames; discard the rest. #### Progressive training The Open-Sora Plan v1.3 uses more data for training and employs a progressive training approach to help the model understand frame-based inpainting tasks. Since the Inpainting Model supports various mask inputs, different mask inputs correspond to tasks of varying difficulty levels. Therefore, we can first teach the model simple tasks, such as random masks, allowing it to develop a basic capability for frame-based inpainting before gradually increasing the proportion of more challenging tasks. It is important to note that at different training stages, we ensure that at least 5% of the data the model sees pertains to T2V tasks, which is aimed at enhancing the model's understanding of prompts. The model weights are initialized from the T2V model with zero initialization. The batch size is fixed at 256, and the learning rate is set to 1e-5, using a two-stage training approach. **Stage 1**: Any resolution and duration within 93x102400 (320x320), using unfiltered motion and aesthetic low-quality data: (1) Step 1: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 50% of the frames are retained during continuation and random mask, training with 4 million samples. (2) Step 2: t2v 10%, continuation 40%, random mask 40%, clear 10%. Ensure that at least 25% of the frames are retained during continuation and random mask, training with 4 million samples. (3) Step 3: t2v 10%, continuation 40%, random mask 40%, clear 10%. 
Ensure that at least 12.5% of the frames are retained during continuation and random mask, training with 4 million samples. (4) Step 4: t2v 10%, continuation 25%, random mask 60%, clear 5%. Ensure that at least 12.5% of the frames are retained during continuation and random mask, training with 4 million samples. (5) Step 5: t2v 10%, continuation 25%, random mask 60%, clear 5%, training with 8 million samples. (6) Step 6: t2v 10%, continuation 10%, random mask 20%, i2v 40%, transition 20%, training with 16 million samples. (7) Step 7: t2v 5%, continuation 5%, random mask 10%, i2v 40%, transition 40%, training with 10 million samples. **Stage 2:** Any resolution and duration within 93x236544 (e.g., 480x480, 640x352, 352x640), using filtered motion and aesthetic high-quality data: t2v 5%, continuation 5%, random mask 10%, i2v 40%, transition 40%, training with 15 million samples. #### About the Semantic Adapter We conducted further experiments on the Semantic Adapter module and compared the video quality of Image-to-Video under various Image Encoders, including [Clip](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K) and [Dino v2](https://huggingface.co/timm/vit_large_patch14_dinov2.lvd142m). We also attempted strategies such as directly injecting image embeddings into cross-attention or extracting features from the Image Encoder using Qformer before injecting them into cross-attention. Under various strategies, we did not observe significant performance improvements; the impact on video quality was much smaller than that of the dataset. Therefore, we decided not to include the Semantic Adapter in Open-Sora Plan v1.3. 
#### Noise Injection Strategy for Conditional Images Research such as [CogVideoX](https://arxiv.org/abs/2408.06072) and [Stable Video Diffusion](https://stability.ai/stable-video) has indicated that adding a certain amount of noise to Conditional Images can enhance the generalization capability of I2V models and achieve a greater range of motion. Therefore, we will implement this strategy in Open-Sora Plan v1.3, just the same as in [CogVideoX](https://arxiv.org/abs/2408.06072). ### The implementation of Skiparse Attention Skiparse is theoretically easy to understand and straightforward to implement. Its implementation mainly relies on the rearrange operation, which reduces the sequence length of latents before entering `F.scaled_dot_product_attention()`. Aside from this adjustment, no other modifications are made. For simplicity, the following discussion focuses solely on the self-attention part, excluding the attention mask. The pseudocode implementation of Single Skip is as follows: ```python # x.shape: (B,N,C) def single_skip_rearrange(x, sparse_k): return rearrange(x, 'b (g k) d -> (k b) g d', k=sparse_k) def reverse_sparse(x, sparse_k): return rearrange(x, '(k b) g d -> b (g k) d', k=sparse_k) q, k, v = Q(x), K(x), V(x) q = add_rope(q) k = add_rope(k) q = single_skip_rearrange(q) k = single_skip_rearrange(k) v = single_skip_rearrange(v) hidden_states = F.scaled_dot_product_attention(q=q,k=k,v=v) output = reverse_sparse(hidden_states) ``` The core of the Skiparse operation lies in "rearranging the sequence", which corresponds to the Single Skip operation in the pseudocode: ```python rearrange(x, 'b (g k) d -> (k b) g d', k=sparse_k) ``` This operation can be understood as a combination of a reshape and a transpose operation:
In this way, $$k$$ sub-sequences can be created, and $$k$$ can be moved to the batch dimension, allowing the Attention mechanism to compute the sub-sequences in parallel. Understanding Single Skip makes Group Skip easy to comprehend as well; it simply adds a grouping operation before the Skip. Its pseudocode is as follows: ```python # x.shape: (B,N,C) def group_skip_rearrange(x, sparse_k): return rearrange(x, ' b (n m k) d -> (m b) (n k) d', m=sparse_k, k=sparse_k) def reverse_sparse(x, sparse_k): return rearrange(x, '(m b) (n k) d -> b (n m k) d', m=sparse_k, k=sparse_k) q, k, v = Q(x), K(x), V(x) q = add_rope(q) k = add_rope(k) q = group_skip_rearrange(q) k = group_skip_rearrange(k) v = group_skip_rearrange(v) hidden_states = F.scaled_dot_product_attention(q=q,k=k,v=v) output = reverse_sparse(hidden_states) ``` Every $$k^2$$ tokens form a repetition, and every $$k$$ tokens form a group. To help everyone better understand this operation, the following figure illustrates the situation when $$k=3$$:
It is important to note that the rope is added before the Skiparse operation and cannot be placed after it, as the sequence after Skiparse will lose its original spatial positions. ## Future Work and Discussion ### CasualVideoVAE For videos, increasing the compression ratio while maintaining the original latent dimension leads to significant information loss. Therefore, it is a trend to increase the latent dimension to achieve higher compression ratios. A more advanced VAE will be released in the next version. ### Diffusion Model The current 2B model in version 1.3.0 shows performance saturation during the later stages of training. However, it does not perform well in understanding physical laws (e.g., a cup overflowing with milk, a car moving forward, or a person walking). We have 4 hypotheses regarding this issue: #### The current data domain is too narrow. We randomly sampled 2,000 videos from Panda70m and conducted manual verification, finding that less than 1% featured cars in motion, and there were even fewer than 10 videos of people walking. Approximately 80% of the videos consist of half-body conversations with multiple people in front of the camera. Therefore, we speculate that the narrow data domain of Panda70m restricts the model's ability to generate many scenarios. We plan to collect more data in the next version. #### Joint training of images and videos Models such as [Open-Sora v1.2](https://github.com/hpcaitech/Open-Sora), [EasyAnimate v4](https://github.com/aigc-apps/EasyAnimate), and [Vchitect-2.0](https://github.com/Vchitect/Vchitect-2.0) can easily generate high-visual-quality videos, possibly due to their direct inheritance of image weights ([Pixart-Sigma](https://pixart-alpha.github.io/PixArt-sigma-project/), [HunyuanDiT](https://github.com/Tencent/HunyuanDiT), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)). 
They train the model with a small amount of video data to learn how to flow along the time axis based on 2D images. However, we trained images from scratch with only 10M-level data, which is far from sufficient. We have two hypotheses regarding the training strategy: (1) the first is to start joint training from scratch, with images significantly outnumbering videos; (2) The second is to first train a high-quality image model and then use joint training, with a higher proportion of videos at that stage. Considering the learning path and training costs, the second approach may offer more decoupling, while the first aligns better with scaling laws. #### The model still needs to scale By observing the differences between [CogVideoX-2B](https://github.com/THUDM/CogVideo) and its 5B variant, we can clearly see that the 5B model understands more physical laws than the 2B model. We speculate that instead of spending excessive effort designing for smaller models, it may be more effective to leverage scaling laws to solve these issues. In the next version, we will scale up the model to explore the boundaries of video generation. We currently have two plans: one is to continue using the Deepspeed/FSDP approach, sharding the EMA and text encoder across ranks with Zero3, which is sufficient for training 10-15B models. The other is to adopt [MindSpeed](https://gitee.com/ascend/MindSpeed) for various parallel strategies, enabling us to scale the model up to 30B. #### Supervised loss in training Whether flow-based models are more suitable than v-pred models remains uncertain and requires further ablation studies to determine. ### How else can "Skiparse" skip? The sparse method we use is theoretically and practically straightforward; however, its implementation treats the original video data purely as a one-dimensional sequence, neglecting the 2D spatial priors. Thus, we extended Skiparse to create Skiparse-2D, which is better suited for 2D Visuals.
In Skiparse-2D, a sparse ratio of $$k$$ represents the sparsity along the $$h$$ or $$w$$ direction. In terms of the number of tokens involved in attention computation, it is equivalent to the square of the sparse ratio in Skiparse-1D. We conducted basic experiments comparing Skiparse-1D and Skiparse-2D. Under identical experimental settings, Skiparse-2D showed no improvement over Skiparse-1D in terms of loss or sampling results. Additionally, Skiparse-2D is less flexible to implement than Skiparse-1D. Therefore, we opted to use the Skiparse-1D approach for training in Open-Sora Plan v1.3. Nevertheless, given our limited experimentation, the feasibility of Skiparse-2D remains worth exploring. Intuitively, Skiparse-2D better aligns with the spatial characteristics of visuals, and as sparse ratio $$k$$ increases, its approach intuitively approximates that of 2+1D. We therefore encourage interested researchers in the community to pursue further exploration in this area. ================================================ FILE: docs/Report-v1.5.0.md ================================================ ## Report v1.5.0 In October 2024, we released Open-Sora Plan v1.3.0, introducing the sparse attention structure, Skiparse Attention, to the field of video generation for the first time. Additionally, we adopted the efficient WFVAE, significantly reducing encoding time and memory usage during training. In Open-Sora Plan v1.5.0, we introduce several key updates to enhance the framework: 1、Improved Sparse DiT, SUV. Building on Skiparse Attention, we extend sparse DiT into a U-shaped sparse structure. This design preserves speed advantages while enabling sparse DiT to achieve performance comparable to dense DiT. 2、Higher-compression WFVAE. In Open-Sora Plan v1.5.0, we explore a WFVAE with an 8×8×8 downsampling rate. It outperforms the widely adopted 4×8×8 VAE in the community, while reducing the latent shape by half and shortening the attention sequence length. 
3、Data and model scaling. In Open-Sora Plan v1.5.0, we collect 1.1 billion high-quality images and 40 million high-quality videos. The model is scaled up to 8.5 billion parameters, resulting in strong overall performance. 4、Simplified Adaptive Gradient Clipping strategy. Compared to the more complex batch-dropping method in version 1.3.0, version 1.5.0 maintains a simple adaptive gradient norm threshold for clipping, making it more compatible with various parallel training strategies. Open-Sora Plan v1.5.0 is fully trained and inferred on Ascend 910-series accelerators, using the mindspeed-mm framework to support parallel training strategies. ### Open-Source Release Open-Sora Plan v1.5.0 is open-sourced with the following components: 1、All training and inference code. You can also find the implementation of Open-Sora Plan v1.5.0 in the official [MindSpeed-MM](https://gitee.com/ascend/MindSpeed-MM) repository. 2、The WFVAE weights with 8×8×8 compression, along with the 8.5B SUV denoiser weights. ## Detailed Technical Report ### Data collection and processing Our dataset includes 1.1B images from [Recap-DataComp-1B](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B)、[COYO-700M](https://github.com/kakaobrain/coyo-dataset)、[LAION-Aesthetics](https://laion.ai/blog/laion-aesthetics/), with no filtering applied aside from resolution checks. The video data are drawn from [Panda-70M](https://github.com/snap-research/Panda-70M) and internal sources, and filtered using the same protocol as in Open-Sora Plan v1.3.0, yielding 40M high-quality videos. ### Adaptive Grad Clipping In Open-Sora Plan v1.3.0, we introduce an Adaptive Grad Clipping strategy based on discarding gradient-abnormal batches. While highly stable, this method involves overly complex execution logic. In Open-Sora Plan v1.5.0, we optimize the strategy by maintaining the gradient norm threshold via an exponential moving average (EMA). Gradients exceeding the threshold are clipped accordingly. 
This approach effectively extends the fixed threshold of 1.0, which is commonly used in large-scale models, into a dynamic, training-dependent threshold. ```python ''' moving_avg_max_grad_norm: the maximum gradient norm maintained via EMA moving_avg_max_grad_norm_var: the variance of the maximum gradient norm maintained via EMA clip_threshold: the gradient clipping threshold computed using the 3-sigma rule ema_decay: the EMA decay coefficient, typically set to 0.99. grad_norm: grad norm at the current step ''' clip_threshold = moving_avg_max_grad_norm + 3.0 * (moving_avg_max_grad_norm_var ** 0.5) if grad_norm <= clip_threshold: # If the gradient norm is below the clipping threshold, the parameters are updated normally at this step, and both the moving_avg_max_grad_norm and moving_avg_max_grad_norm_var are updated accordingly. moving_avg_max_grad_norm = ema_decay * moving_avg_max_grad_norm + (1 - ema_decay) * grad_norm max_grad_norm_var = (moving_avg_max_grad_norm - grad_norm) ** 2 moving_avg_max_grad_norm_var = ema_decay * moving_avg_max_grad_norm_var + (1 - ema_decay) * max_grad_norm_var # update weights... else: # If the gradient norm exceeds the clipping threshold, the gradients are first clipped to reduce the norm to the threshold value before updating the parameters. clip_coef = grad_norm / clip_threshold grads = clip(grads, clip_coef) # clipping grads # update weights... ``` Compared to the strategy in v1.3.0, this approach is simpler to implement and effectively addresses the issue of loss spikes that occur in the later stages of diffusion training when the gradient norm is significantly below 1.0. ### WFVAE with 8x8x8 compression In version 1.5.0, we increase the temporal compression rate of the VAE from 4× to 8×, reducing the latent shape to half that of the previous version. This enables the generation of videos with higher frame counts. 
| Model | THW(C) | PSNR | LPIPS | rFVD | | ----------------- | ------------- | ------------ | ------------- | ------------ | | CogVideoX | 4x8x8 (16) | 36.38 | 0.0243 | 50.33 | | StepVideo | 8x16x16 (16) | 33.61 | 0.0337 | 113.68 | | LTXVideo | 8x32x32 (128) | 33.84 | 0.0380 | 150.87 | | Wan2.1 | 4x8x8 (16) | 35.77 | **0.0197** | **46.05** | | Ours (WF-VAE-M) | 8x8x8 (32) | **36.91** | 0.0205 | 52.53 | **Test on an open-domain dataset with 1K samples.** For more details on WFVAE, please refer to [WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model](https://arxiv.org/abs/2411.17459) ### Training Text-to-Video Diffusion Model #### Framework —— SUV: A Sparse U-shaped Diffusion Transformer For Fast Video Generation In Open-Sora Plan v1.3.0, we discuss the strengths and weaknesses of Full 3D Attention and 2+1D Attention. Based on their characteristics, we propose Skiparse Attention, a novel global sparse attention mechanism. Under a predefined sparsity $k$, Skiparse Attention selects a subsequence of length $\frac{1}{k}$ of the original sequence in an alternating Single-Skip and Group-Skip pattern for attention interaction. This design approximates the effect of Full 3D Attention. As the sparsity increases, the selected positions become more widely spaced; as it decreases, the positions become more concentrated. Regardless of the sparsity, Skiparse Attention remains global. In Open-Sora Plan v1.5.0, we interpret this sparse interaction pattern as a form of token-level information downsampling. Sparser Skiparse Attention performs more semantic-level interactions, while denser Skiparse Attention captures fine-grained information. Following the multi-scale design principle in neural networks, we introduce Skiparse Attention with U-shaped sparsity variation: low-sparsity Skiparse Attention is used in shallow layers, with Full 3D Attention applied at the shallowest layer, and high-sparsity Skiparse Attention in deeper layers. 
Inspired by the UNet architecture, we further incorporate long skip connections between stages with identical sparsity. This U-shaped DiT architecture based on Skiparse Attention is referred to as **SUV**. ![SUV](https://github.com/user-attachments/assets/6eb54e37-7077-4746-a4c6-9b7165dd48fe) In Open-Sora Plan v1.5.0, we adopt an SUV architecture based on MMDiT. Skiparse Attention is applied to the video latents, while the text embeddings are only repeated to align with the skiparse-processed latent shape, without any sparsification. The SUV architecture offers the following advantages: 1、SUV is the first sparsification method proven effective for video generation. Our ablation studies show that it achieves performance comparable to dense DiT within the same number of training steps. Moreover, it can be applied during both pretraining and inference. Testing on the Ascend 910B platform at 121×576×1024 shape shows SUV runs over 35% faster than Dense DiT, with the attention operation alone gaining a speed boost of over 45%. 2、Unlike UNet structures that explicitly downsample feature maps and cause information loss, the U-shaped structure of SUV operates on attention. The shape of the feature map remains unchanged, preserving information while altering only the granularity of token-level interactions. 3、Skiparse Attention and SUV only change the attention computation during the forward pass instead of modifying model weights. This allows dynamic adjustment of sparsity throughout training: lower sparsity for image or low-resolution video training, and higher sparsity for high-resolution video training. As a result, FLOPs grow approximately linearly with sequence length. A more detailed analysis of the SUV architecture will be released in a future arXiv update. #### Training Stage Our training consists of two stages: Text-to-Image and Text-to-Video.
#### Text-to-Image Previous studies have shown that image weights trained on synthetic data may negatively impact video training. Therefore, in the v1.5.0 update, we choose to train image weights using a much larger corpus of real-world data, totaling 1.1B images. Since image data come in various resolutions, whereas videos are primarily in a 9:16 aspect ratio, we adopt multi-resolution training for images using five common aspect ratios—(1,1), (3,4), (4,3), (9,16), and (16,9)—along with the Min-Max Token Strategy. In contrast, video training is conducted using a fixed resolution with a 9:16 aspect ratio. The difference between Skiparse Attention and Full Attention lies in the token sequences involved in the forward computation; the required weights remain identical. Therefore, we can first train the model using Dense MMDiT with Full 3D Attention, and then fine-tune it to the Sparse MMDiT mode after sufficient training. **Image-Stage-1:** Training is conducted using 512 Ascend 910B accelerators. We train a randomly initialized Dense MMDiT on 256²-pixel images with multi-resolution enabled. The learning rate is set to 1e-4, with a batch size of 8096. This stage runs for a total of 225k steps. **Image-Stage-2:** Training is conducted using 384 Ascend 910B accelerators. We train on 384²-pixel images with multi-resolution still enabled. The learning rate remains 1e-4, the batch size is 6144, and training lasts for 150k steps. **Image-Stage-3:** Training is conducted using 256 Ascend 910B accelerators. We train on 288x512 images with a fixed resolution. The learning rate is 1e-4, the batch size is 4096, and training lasts for 110k steps. This stage completes the Dense MMDiT training. **Image-Stage-4:** Training is conducted using 256 Ascend 910B accelerators. We initialize the SUV model using the pretrained weights from Dense MMDiT, with skip connections zero-initialized to ensure that the model could produce non-noise outputs at the start.
In practice, zero-shot inference reveals that the generated images contain meaningful low-frequency structures. Our experiments confirm that fine-tuning from Dense DiT to SUV converges quickly. This stage uses a fixed resolution of 288×512, a learning rate of 1e-4, a batch size of 4096, and is trained for approximately 160k steps. #### Text-to-Video For video training, we fix the aspect ratio at 9:16 and train solely on video data instead of joint training with image data. All training in this stage is performed using 512 Ascend 910B accelerators. **Video-Stage-1:** Starting from the SUV weights pretrained during the Text-to-Image phase, we train on videos with a shape of 57×288×512 for about 40k steps. The setup includes a learning rate of 6e-5, TP/SP parallelism of 2, gradient accumulation set to 2, a micro batch size of 2, and a global batch size of 1024. Videos are trained at 24 fps, representing approximately 2.4 seconds (57/24 ≈ 2.4s) of content per sample. This stage marks the initial adaptation from image-based to video-based weights, for which shorter video clips are intentionally selected to ensure stable initialization. **Video-Stage-2:** We further train on videos with a shape of 57×288×512 for 45k steps, keeping the learning rate, TP/SP parallelism, and gradient accumulation settings unchanged. However, the training frame rate is reduced to 12 fps, corresponding to ~4.8 seconds of video content per sample (57/12 ≈ 4.8s). This stage aims to enhance temporal learning without increasing sequence length, serving as preparation for later high-frame-count training. **Video-Stage-3:** We train on videos with a shape of 121×288×512 for approximately 25k steps. The learning rate is adjusted to 4e-5, with TP/SP parallelism set to 4, gradient accumulation steps set to 2, a micro batch size of 4, and a global batch size of 1024. In this stage, we revert to a training frame rate of 24 fps.
**Video-Stage-4:** We conduct training on videos with a shape of 121×576×1024 for a total of 16k + 9k steps. The learning rates are set to 2e-5 and 1e-5 for the two phases, respectively. TP/SP parallelism is configured as 4, with gradient accumulation steps set to 4, a micro batch size of 1, and a global batch size of 512. **Video-Stage-5:** We train on a high-quality subset of the dataset for 5k steps, using a learning rate of 1e-5. TP/SP parallelism is set to 4, with gradient accumulation steps of 4, a micro batch size of 1, and a global batch size of 512. #### Performance on Vbench | Model | Parameters | Total Score | Quality Score | Semantic Score | **aesthetic quality** | | -------------------------- | ---------- | ------------- | ------------- | -------------- | --------------------- | | Mochi-1 | 10B | 80.13% | 82.64% | 70.08% | 56.94% | | CogvideoX-2B | 2B | 80.91% | 82.18% | 75.83% | 60.82% | | CogvideoX-5B | 5B | 81.61% | 82.75% | 77.04% | 61.98% | | Step-Video-T2V | 30B | 81.83% | 84.46% | 71.28% | 61.23% | | CogvideoX1.5-5B | 5B | 82.17% | 82.78% | **79.76%** | 62.79% | | Gen-3 | - | 82.32% | 84.11% | 75.17% | 63.34% | | HunyuanVideo (Open-Source) | 13B | **83.24%** | **85.09%** | 75.82% | 60.36% | | Open-Sora Plan v1.5.0 | 8B | 83.02% | 84.24% | 78.18% | **66.89%** | ### Training Image-to-Video Diffusion Model Coming Soon... ### Future Work Currently, open-source models such as Wan2.1 have achieved performance comparable to closed-source commercial counterparts. Given the gap in computing resources and data availability compared to industry-scale efforts, the future development of the Open-Sora Plan will focus on the following directions: 1、Latents Cache。 In the training process of Text-to-Video models, the data must be processed through two key modules—the Variational Autoencoder (VAE) and the Text Encoder—to extract features from both video/images and their corresponding prompts. These encoded features serve as inputs to the training model. 
However, in existing industry practices, feature encoding is redundantly performed on the multimodal training dataset during every training epoch. This leads to additional computational overhead and significantly prolongs the total training time. Specifically, in conventional training pipelines, the VAE and Text Encoder modules are typically kept resident in GPU memory to perform feature encoding in real time during each epoch. While this ensures on-the-fly encoding, it also results in persistently high GPU memory usage, becoming a major bottleneck for training efficiency. This issue is exacerbated when handling large-scale datasets or complex models, where memory constraints further limit model capacity and training speed. To address the above issue, we propose an optimization strategy that replaces repeated feature computation with feature lookup. The core idea is to decouple feature encoding from model training. Specifically, during pretraining or the first training epoch, we compute and store the most computationally expensive text prompt features in external high-performance storage. During subsequent training, the model directly loads these precomputed features from storage, avoiding redundant encoding operations. This design significantly reduces computational overhead and GPU memory usage, allowing more memory to be allocated to model training. Based on the following configuration environment, we compare the training time per epoch and per step before and after applying the feature caching strategy. Experimental results show that storing precomputed features reduces multi-epoch training time by approximately 30% and frees up around 20% of GPU memory resources. 
| **Configuration** | **Details** | | :---------------: | :-----------------------------------------: | | Model | Open-Sora Plan v1.5.0 (2B-level parameters) | | Dataset | 100K images and 10K videos | | Accelerators | 8× Nvidia A800 GPUs | | Feature Storage | Huawei OceanStor AI Storage | Test cases: | **Training Stage** | **Test Type** | **Batch Size** | **Time per Step** | **Time per Epoch** | **Memory Usage** | | ------------------ | ---------------------- | -------------- | ----------------- | ------------------ | ---------------- | | Low-Res Images | General Method | 64 | 6.53s | 21 min 12s | 56 GB | | | Feature Caching Method | 64 | 4.10s | 13 min 19s | 40 GB | | | General Method | 128 | 12.78s | 20 min 39s | 74 GB | | | Feature Caching Method | 128 | 7.81s | 12 min 38s | 50 GB | | Low-Res Videos | General Method | 8 | 8.90s | 26 min 23s | 68 GB | | | Feature Caching Method | 8 | 7.78s | 23 min 05s | 51 GB | | High-Res Videos | General Method | 4 | 17.00s | 101 min | 71 GB | | | Feature Caching Method | 4 | 16.00s | 97 min | 57 GB | 2、Improved DiT pretraining with sparse or linear attention. In v1.3.0, we introduce the first DiT pretrained with sparse attention in the community. This is extended in v1.5.0 into the SUV architecture, enabling sparse DiT to achieve performance comparable to its dense counterpart. While sparse and linear attention have demonstrated significant success in the LLM domain, their application in video generation remains underexplored. In future versions, we plan to further investigate the integration of sparse and linear attention into video generation models. 3、MoE-based DiT. Since the release of [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), the MoE (Mixture-of-Experts) paradigm has become a common approach for scaling LLMs to larger parameter sizes. Currently, open-source video generation models are capped at around 14B parameters, which is still relatively small compared to the 100B+ scales in the LLM field. 
Incorporating MoE into the DiT architecture, and exploring its combination with sparse and linear attention, is a future direction under consideration by the Open-Sora Plan team. 4、Unified video generation models for both generation and understanding. The March release of GPT-4o demonstrates that unified architectures combining generation and understanding can offer fundamentally different capabilities compared to purely generative models. In the video domain, we should similarly anticipate the potential breakthroughs that such unified generative models might bring. 5、Enhancing Image-to-Video generation models. Current approaches in this field still largely follow either the SVD paradigm or the inpainting-based paradigm adopted since Open-Sora Plan v1.2.0. Both approaches require extensive fine-tuning of pretrained Text-to-Video models. From a practical standpoint, Text-to-Video is more aligned with academic exploration, while Image-to-Video is more relevant to real-world production scenarios. As a result, developing a new paradigm for Image-to-Video will be a key focus for the Open-Sora Plan team moving forward. 
================================================ FILE: docs/Report-v1.5.0_cn.md ================================================ ## Report v1.5.0 在2024年的10月,我们发布了Open-Sora Plan v1.3.0,第一次将一种稀疏化的attention结构——skiparse attention引入video generation领域。同时,我们采用了高效的WFVAE,使得训练时的编码时间和显存占用大大降低。 在Open-Sora Plan v1.5.0中,Open-Sora Plan引入了几个关键的更新: 1、更好的sparse dit——SUV。在skiparse attention的基础上,我们将sparse dit扩展至U形变化的稀疏结构,使得在保持速度优势的基础上sparse dit可以取得和dense dit相近的性能。 2、更高压缩率的WFVAE。在Open-Sora Plan v1.5.0中,我们尝试了8x8x8下采样率的WFVAE,它在性能上媲美社区中广泛存在的4x8x8下采样率的VAE的同时latent shape减半,降低attention序列长度。 3、data和model scaling。在Open-Sora Plan v1.5.0中,我们收集了1.1B的高质量图片数据和40m的高质量视频数据,并将模型大小scale到8.5B,使最终得到的模型呈现出不俗的性能。 4、更简易的Adaptive Grad Clipping。相比于version 1.3.0中较复杂的丢弃污点batch的策略,在version 1.5.0中我们简单地维护一个adaptive的grad norm threshold并clipping,以此更适应各种并行策略的需要。 Open-Sora Plan v.1.5.0全程在昇腾910系列加速卡上完成训练和推理,并采用mindspeed-mm训练框架适配并行策略。 ### Open-Source Release Open-Sora Plan v1.5.0的开源包括: 1、所有训练和推理代码。你也可以在[MindSpeed-MM](https://gitee.com/ascend/MindSpeed-MM)官方仓库找到open-sora plan v1.5.0版本的实现。 2、8x8x8下采样的WFVAE权重以及8.5B的SUV去噪器权重。 ## Detailed Technical Report ### Data collection and processing 我们共收集了来自Recap-DataComp-1B、Coyo700M、Laion-aesthetic的共1.1B图片数据。对于图片数据,我们不进行除了分辨率之外的筛选。我们的视频数据来自于Panda70M以及其他自有数据。对于视频数据,我们采用与Open-Sora Plan v1.3.0相同的处理策略进行筛选,最终数据量为40m的高质量视频数据。 ### Adaptive Grad Clipping 在Open-Sora Plan v1.3.0中,我们介绍了一种基于丢弃梯度异常batch的Adaptive Grad Clipping策略,这种策略具有很高的稳定性,但是执行逻辑过于复杂。因此,在Open-Sora Plan v1.5.0中,我们选择将该策略进行优化,采用EMA方式维护grad norm的threshold,并在grad norm超过该threshold时裁剪到threshold以下。该策略本质上是将大模型领域常用的1.0常数grad norm threshold扩展为一个随着训练进程动态变化的threshold。 ```python ''' moving_avg_max_grad_norm: EMA方式维护的最大grad norm moving_avg_max_grad_norm_var: EMA方式维护的最大grad norm的方差 clip_threshold: 根据3 sigma策略计算得到的梯度裁剪阈值 ema_decay: EMA衰减系数,一般为0.99 grad_norm: 当前step的grad norm ''' clip_threshold = moving_avg_max_grad_norm + 3.0 * (moving_avg_max_grad_norm_var ** 0.5) if grad_norm <= clip_threshold: # grad 
norm小于裁剪阈值,则该step参数正常更新,同时更新维护的moving_avg_max_grad_norm 和 moving_avg_max_grad_norm_var moving_avg_max_grad_norm = ema_decay * moving_avg_max_grad_norm + (1 - ema_decay) * grad_norm max_grad_norm_var = (moving_avg_max_grad_norm - grad_norm) ** 2 moving_avg_max_grad_norm_var = ema_decay * moving_avg_max_grad_norm_var + (1 - ema_decay) * max_grad_norm_var 参数更新... else: # grad norm大于裁剪阈值,则先裁剪grad使grad norm减少至clip_threshold,再进行参数更新。 clip_coef = grad_norm / clip_threshold grads = clip(grads, clip_coef) # 裁剪grads 参数更新... ``` 该策略相较于v1.3.0中策略实现更简单,且能够很好应对diffusion训练后期grad norm远小于1.0时仍存在loss spike的问题。 ### WFVAE with 8x8x8 downsampling 在V1.5.0版本中,我们将VAE的时间压缩率从4倍压缩提高至8倍压缩,使得对于同样原始尺寸的视频,latent shape减少为先前版本的一半,这使得我们可以实现更高帧数的视频生成。 | Model | THW(C) | PSNR | LPIPS | rFVD | | ----------------- | ------------- | ------------ | ------------- | ------------ | | CogVideoX | 4x8x8 (16) | 36.38 | 0.0243 | 50.33 | | StepVideo | 8x16x16 (16) | 33.61 | 0.0337 | 113.68 | | LTXVideo | 8x32x32 (128) | 33.84 | 0.0380 | 150.87 | | Wan2.1 | 4x8x8 (16) | 35.77 | **0.0197** | **46.05** | | Ours (WF-VAE-M) | 8x8x8 (32) | **36.91** | 0.0205 | 52.53 | **Test on an open-domain dataset with 1K samples.** WFVAE详情请见[WF-VAE: Enhancing Video VAE by Wavelet-Driven Energy Flow for Latent Video Diffusion Model](https://arxiv.org/abs/2411.17459) ### Training Text-to-Video Diffusion Model #### Framework —— SUV: A Sparse U-shaped Diffusion Transformer For Fast Video Generation 在Open-Sora Plan v1.3.0中,我们讨论了Full 3D Attention以及2+1D Attention的优劣,并综合他们的特点提出了Skiparse Attention——一种新型的global sparse attention。 在一个事先指定的sparse ratio $k$ 下,Skiparse Attention按照Single Skip - Group Skip交替的方式选定原序列长度 $\frac{1}{k}$ 的子序列进行attention交互,以此达到近似Full 3D Attention的效果。在Skiparse Attention中,sparse ratio越大,子序列在原序列中的位置越稀疏;sparse ratio越小,子序列在原序列中的位置越密集。但无论sparse ratio为多少,Skiparse Attention总是global的。 在Open-Sora Plan v1.5.0中,我们将这种稀疏交互方式看作一种token上的信息下采样,越稀疏的Skiparse Attention是一种更偏语义级的信息交互,越密集的Skiparse 
Attention是一种更偏细粒度的信息交互。遵循神经网络中多尺度设计的准则,我们在网络中引入U形变化稀疏度的Skiparse Attention,即浅层采用稀疏度低的Skiparse Attention,并在最浅层使用Full 3D Attention,深层采用稀疏度高的Skiparse Attention。特别的,类比UNet的设计,我们在相同稀疏度的Stage之间引入了Long Skip Connection。我们将这种U形变化的基于Skiparse Attention的DiT称之为SUV。 ![SUV](https://github.com/user-attachments/assets/6eb54e37-7077-4746-a4c6-9b7165dd48fe) 在Open-Sora Plan v1.5.0中我们采用了基于MMDiT的SUV架构。对于video latents,我们对其进行skiparse attention操作,对于text embedding,我们仅对其进行repeat以对齐skiparse后的latent shape而不进行任何稀疏化操作。 SUV架构存在以下优点: 1、SUV是首个在视频生成模型上验证有效的稀疏化方法,在我们的消融实验中表明其在同样训练步数下可以达到接近dense dit的性能,且可以同时应用于预训练和推理中。在910B测试平台下,在121x576x1024的视频shape下,SUV的推理速度相比Dense DiT提升35%以上,其中Attn部分速度提升45%以上。 2、相较于UNet结构对feature map进行显式的下采样造成了信息损失,SUV的U形结构作用在Attention上,feature map的shape并没有发生变化,即信息并未发生损失,改变的只是token间信息交互的粒度。 3、Skiparse Attention及SUV不改变权重大小,只改变forward时attention的计算方式。这使得我们可以随着训练进程动态调整稀疏度,在图片训练或低分辨率视频训练时采用较低的稀疏度,在高分辨率视频训练时提高稀疏度,获得随序列长度近似线性增长的FLOPS。 对SUV架构更细致的分析,将会在后续更新至arxiv。 #### Training Stage 我们的训练包括Text-to-Image和Text-to-Video两个阶段。 #### Text-to-Image 先前的工作表明从合成数据训练得到的图像权重可能会影响视频训练时的效果。因此,在v1.5.0更新中,我们选择在更大的真实数据域内训练图像权重。我们收集了共1.1B的图片数据进行训练。由于图片存在多种不同的分辨率,而视频主要为9:16分辨率,因此我们选择在训练图片权重时开启多分辨率(5个常见宽高比:(1,1), (3,4), (4,3), (9,16), (16,9) )及Min-Max token Strategy训练,而在训练视频时采用固定9:16的宽高比固定分辨率训练。 Skiparse Attention与Full Attention的区别在于前向过程中参与计算的token序列不同,所需要的权重变量则完全相同。因此,我们可以先用Full 3D Attention的Dense MMDiT做训练,并在训练充分后Fine-tune至Sparse MMDiT模式。 **Image-Stage-1:** 采用512张Ascend 910B进行训练。 我们采用随机初始化的Dense MMDiT在256^2px级别分辨率的图片上训练,开启多分辨率。学习率为1e-4,batch size为8096。在这个阶段我们总共训练了225k steps。 **Image-Stage-2:** 采用384张Ascend 910B进行训练。在384^2px级别的图片上训练,开启多分辨率训练。学习率为1e-4,batch size为6144,共训练150k step。 **Image-Stage-3:** 采用256张Ascend 910B进行训练。固定288x512分辨率训练。学习率为1e-4,batch size为4096,共训练110k step。Dense MMDiT阶段训练完成。 **Image-Stage-4:** 采用256张Ascend 910B进行训练。采用Dense MMDiT的权重初始化SUV,其中skip connection采用零初始化,保证初始SUV权重能够推出非噪声图片。事实上,zero shot推理得到的图片具备一定的低频信息,我们验证了Dense DiT到SUV的finetune可以很快达成。该阶段固定分辨率为288x512,学习率为1e-4,batch size为4096,共训练约160k
step。 #### Text-to-Video 在训练视频时,我们采用的宽高比固定为9:16,且并未采用视频图像联合训练,而是仅用视频数据做训练。以下训练均在512张Ascend 910B上完成。 **Video-Stage-1:** 继承Text-to-Image阶段得到的SUV权重,我们在57x288x512的视频上训练了大约40k step,学习率为6e-5,TP/SP并行度为2,梯度累积次数为2, micro batch size为2,global batch size为1024。在这个阶段,我们采用的train fps为24,即大约57/24≈2.4s的视频内容。该阶段作为图片权重到视频权重迁移的第一个阶段,我们选择了较短的视频训练作为良好的初始化。 **Video-Stage-2:** 我们同样在57x288x512的视频上训练45k step,学习率、TP/SP并行度和梯度累积设置保持不变,但是train fps更改为12,即对应的原视频长度为57/12≈4.8s的内容。该阶段旨在不增加序列长度的同时提高对时序的学习,为后续高帧数训练阶段做准备。 **Video-Stage-3:** 我们在121x288x512的视频上训练约25k step,学习率调整为4e-5、TP/SP并行度设置为4,梯度累积次数设置为2,micro batch size为4,global batch size为1024。在这个阶段我们重新采用train fps为24。 **Video-Stage-4:** 在121x576x1024的视频上共训练16k + 9k step,学习率分别为2e-5和1e-5,TP/SP并行度设置为4,梯度累积次数设置为4,micro batch size为1,global batch size为512。 **Video-Stage-5:** 我们选择数据中的高质量子集训练了5k step,学习率为1e-5,TP/SP并行度设置为4,梯度累积次数设置为4,micro batch size为1,global batch size为512。 #### Performance on Vbench | Model | Parameters | Total Score | Quality Score | Semantic Score | **aesthetic quality** | | -------------------------- | ---------- | ------------- | ------------- | -------------- | --------------------- | | Mochi-1 | 10B | 80.13% | 82.64% | 70.08% | 56.94% | | CogvideoX-2B | 2B | 80.91% | 82.18% | 75.83% | 60.82% | | CogvideoX-5B | 5B | 81.61% | 82.75% | 77.04% | 61.98% | | Step-Video-T2V | 30B | 81.83% | 84.46% | 71.28% | 61.23% | | CogvideoX1.5-5B | 5B | 82.17% | 82.78% | **79.76%** | 62.79% | | Gen-3 | - | 82.32% | 84.11% | 75.17% | 63.34% | | HunyuanVideo (Open-Source) | 13B | **83.24%** | **85.09%** | 75.82% | 60.36% | | Open-Sora Plan v1.5.0 | 8B | 83.02% | 84.24% | 78.18% | **66.89%** | ### Training Image-to-Video Diffusion Model Coming Soon...
### Future Work 目前,开源社区已经有与闭源商业版本相当性能的模型,如Wan2.1。鉴于算力和数据相比企业来说仍存在不足,后续Open-Sora Plan团队的改进方向为: 1、Latents Cache。 在Text2Video模型的训练过程中,训练数据需要经过变分自编码器(VAE)和文本编码器(Text Encoder)两个关键模块的处理,以实现对视频/图片和对应引导词的特征编码。这些编码后的特征数据作为模型训练的输入,参与后续训练流程。然而业界训练方案中,每个训练周期(Epoch)都需要对多模态训练数据集进行重复的特征编码计算,这不仅增加了额外的计算开销,还显著延长了整体训练时间。 具体而言,在传统的训练流程中,VAE和Text Encoder模型通常需要常驻于GPU显存中,以便在每个Epoch中实时执行特征编码任务。这种设计虽然确保了特征编码的实时性,但也导致了GPU显存占用率居高不下,成为制约训练效率的主要瓶颈之一。尤其是在处理大规模数据集或复杂模型时,显存资源的紧张会进一步加剧这一问题,限制了模型的参数量和训练速度。 为了解决上述问题,我们提出了一种特征值以查代算的优化方案。该方案的核心思想是将特征编码的计算过程与模型训练过程进行解耦。具体实现方式为:在训练前或首轮训练时计算耗时最高的引导词特征值,将其保存至外置高性能文件存储中。后续的训练过程中,模型可以直接从文件存储中读取这些预计算的特征数据,避免了重复的特征编码计算。这种设计不仅显著减少了计算资源的浪费,还大幅降低了GPU显存的占用率,使更多的显存资源可用于模型训练。 基于以下配置环境,统计使用特征数据存储前后的单个epoch及单个step的训练数据。实验表明,特征值存储方案**可缩短约30%多轮迭代训练时间,同时释放约20%显存资源。** | 配置环境 | 详细信息 | | :--------: | :-------------------------------: | | 模型 | Open-Sora Plan v1.5.0 with 2B量级 | | 数据集 | 100K图片及10K视频 | | GPU服务器 | 8张Nvidia A800 | | 特征值存储 | 华为OceanStor AI存储 | 测试数据: | 训练阶段 | 测试类型 | Batch Size | 单Step耗时 | 单Epoch耗时 | 显存占用 | | ------------ | ---------------- | ---------- | ---------- | ----------- | -------- | | 低分辨率图片 | 通用方案 | 64 | 6.53s | 21min12s | 56GB | | | 特征数据存储方案 | 64 | 4.10s | 13min19s | 40GB | | | 通用方案 | 128 | 12.78s | 20min39s | 74GB | | | 特征数据存储方案 | 128 | 7.81s | 12min38s | 50GB | | 低分辨率视频 | 通用方案 | 8 | 8.90s | 26min23s | 68GB | | | 特征数据存储方案 | 8 | 7.78s | 23min05s | 51GB | | 高分辨率视频 | 通用方案 | 4 | 17s | 101min | 71GB | | | 特征数据存储方案 | 4 | 16s | 97min | 57GB | 2、更好的基于稀疏化attention or 线性attention预训练的DiT。在V1.3.0中,我们推出了社区中第一个基于稀疏attention预训练的DiT,并在V1.5.0版本中将其扩展为SUV架构,使稀疏DiT获得了与Dense DiT相当的模型性能。稀疏attention和线性attention在LLM领域已经获得了很大的成功,但在视频生成领域中的应用仍不够明显。在后续版本中,我们将进一步探索稀疏attention和线性attention在video generation领域的应用。 3、基于MoE的DiT。自Mixtral 8x7B发布以来,LLM领域通常会采用MoE的方式将模型scale至更大的参数量。目前开源视频模型的最大大小仅限于14B,相比于LLM领域上百B的参数量来说仍属于小模型。在DiT架构中引入MoE,以及MoE与稀疏attention和线性attention的结合,是Open-Sora Plan团队未来考虑的方向。 
4、生成和理解统一的视频生成模型。3月份gpt-4o的更新让大家认识到了生成理解统一架构的生成模型能够获得与纯生成模型完全不同的能力。在视频领域,我们同样应该期待一个统一的生成模型能够为我们带来哪些惊喜。 5、更好的Image-to-Video模型。目前Image-to-Video领域仍基本遵循SVD范式和Open-Sora Plan v1.2.0起采用的Inpainting范式。这两种范式都需要在Text-to-Video模型权重的基础上进行长时间的finetune。从应用意义上看,Text-to-Video更接近于学术上的探索,而Image-to-Video则更贴近现实的生产环境。因此,Image-to-Video的更新范式也会是Open-Sora Plan团队未来的重点探索方向。 ================================================ FILE: docs/VAE.md ================================================ ### Data prepare The organization of the training data is easy. We only need to put all the videos recursively in a directory. This makes the training more convenient when using multiple datasets. ``` shell Training Dataset |——sub_dataset1 |——sub_sub_dataset1 |——video1.mp4 |——video2.mp4 ...... |——sub_sub_dataset2 |——video3.mp4 |——video4.mp4 ...... |——sub_dataset2 |——video5.mp4 |——video6.mp4 ...... |——video7.mp4 |——video8.mp4 ``` ### Training ``` shell bash scripts/causalvae/train.sh ``` We introduce the important args for training. | Argparse | Usage | |:---|:---| |_Training size_|| |`--num_frames`|The number of using frames for training videos| |`--resolution`|The resolution of the input to the VAE| |`--batch_size`|The local batch size in each GPU| |`--sample_rate`|The frame interval of when loading training videos| |_Data processing_|| |`--video_path`|/path/to/dataset| |_Load weights_|| |`--model_name`| `CausalVAE` or `WFVAE`| |`--model_config`|/path/to/config.json The model config of VAE. If you want to train from scratch use this parameter.| |`--pretrained_model_name_or_path`|A directory containing a model checkpoint and its config. Using this parameter will only load its weight but not load the state of the optimizer| |`--resume_from_checkpoint`|/path/to/checkpoint It will resume the training process from the checkpoint including the weight and the optimizer.| ### Inference ``` shell bash scripts/causalvae/rec_video.sh ``` We introduce the important args for inference. 
| Argparse | Usage | |:---|:---| |_Output video size_|| |`--num_frames`|The number of frames of generated videos| |`--height`|The height of generated videos| |`--width`|The width of generated videos| |_Data processing_|| |`--video_path`|The path to the original video| |`--rec_path`|The path to the generated video| |_Load weights_|| |`--ae_path`|/path/to/model_dir. A directory containing the checkpoint of VAE is used for inference and its model config.json| |_Other_|| |`--enable_tiling`|Use tiling to deal with videos of high resolution and long duration| |`--save_memory`|Reduce memory usage during inference at a slight cost in quality| ### Evaluation The evaluation process consists of two steps: Reconstruct videos in batches: `bash scripts/causalvae/prepare_eval.sh` Evaluate video metrics: `bash scripts/causalvae/eval.sh` To simplify the evaluation, environment variables are used for control. For step 1 (`bash scripts/causalvae/prepare_eval.sh`): ```bash # Experiment name EXP_NAME=wfvae # Video parameters SAMPLE_RATE=1 NUM_FRAMES=33 RESOLUTION=256 # Model weights CKPT=ckpt # Select subset size (0 for full set) SUBSET_SIZE=0 # Dataset directory DATASET_DIR=test_video ``` For step 2 (`scripts/causalvae/eval.sh`): ```bash # Experiment name EXP_NAME=wfvae-4dim # Video parameters SAMPLE_RATE=1 NUM_FRAMES=33 RESOLUTION=256 # Evaluation metric METRIC=lpips # Select subset size (0 for full set) SUBSET_SIZE=0 # Path to the ground truth videos, which can be saved during video reconstruction by setting `--output_origin` ORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin # Path to the reconstructed videos RECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} ``` ================================================ FILE: examples/cond_pix_path.txt ================================================ examples/test_img1.png examples/test_img2.png examples/test_img3.png
================================================ FILE: examples/cond_prompt.txt ================================================ A rocket ascends slowly into the sky. Along the coast, variously sized boats float on the lake. The landscape at sunset is profound and expansive. ================================================ FILE: examples/rec_image.py ================================================ import sys sys.path.append(".") from PIL import Image import torch from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Lambda from torch.nn import functional as F import argparse import numpy as np from opensora.models.causalvideovae import ae_wrapper def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor: transform = Compose( [ ToTensor(), Lambda(lambda x: 2. * x - 1.), Resize(size=short_size), ] ) outputs = transform(video_data) outputs = outputs.unsqueeze(0).unsqueeze(2) return outputs def main(args: argparse.Namespace): image_path = args.image_path short_size = args.short_size device = args.device kwarg = {} # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device) vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device) if args.enable_tiling: vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor vae.eval() vae = vae.to(device) vae = vae.half() with torch.no_grad(): x_vae = preprocess(Image.open(image_path), short_size) x_vae = x_vae.to(device, dtype=torch.float16) # b c t h w latents = vae.encode(x_vae) latents = latents.to(torch.float16) image_recon = vae.decode(latents) # b t c h w x = image_recon[0, 0, :, :, :] x = x.squeeze() x = x.detach().cpu().numpy() x = np.clip(x, -1, 1) x = (x + 1) / 2 x = (255*x).astype(np.uint8) x = x.transpose(1,2,0) image = Image.fromarray(x) image.save(args.rec_path) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--image_path', type=str, default='') 
parser.add_argument('--rec_path', type=str, default='') parser.add_argument('--ae', type=str, default='') parser.add_argument('--ae_path', type=str, default='') parser.add_argument('--model_path', type=str, default='results/pretrained') parser.add_argument('--short_size', type=int, default=336) parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--tile_overlap_factor', type=float, default=0.25) parser.add_argument('--enable_tiling', action='store_true') args = parser.parse_args() main(args) ================================================ FILE: examples/rec_video.py ================================================ import math import random import argparse from typing import Optional import cv2 import numpy as np import numpy.typing as npt import torch from PIL import Image from decord import VideoReader, cpu from torch.nn import functional as F from torchvision.transforms import Lambda, Compose import sys sys.path.append(".") from opensora.models.causalvideovae import ae_wrapper from opensora.dataset.transform import ToTensorVideo, CenterCropResizeVideo def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None: height, width, channels = image_array[0].shape fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) for image in image_array: image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) video_writer.write(image_rgb) video_writer.release() def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None: x = x.detach().cpu() x = torch.clamp(x, -1, 1) x = (x + 1) / 2 x = x.permute(0, 2, 3, 1).float().numpy() x = (255 * x).astype(np.uint8) array_to_video(x, fps=fps, output_file=output_file) return def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: decord_vr = VideoReader(video_path, ctx=cpu(0)) total_frames = len(decord_vr) sample_frames_len = sample_rate 
* num_frames # if total_frames > sample_frames_len: # s = random.randint(0, total_frames - sample_frames_len - 1) # s = 0 # e = s + sample_frames_len # num_frames = num_frames # else: # s = 0 # e = total_frames # num_frames = int(total_frames / sample_frames_len * num_frames) s = 0 e = sample_frames_len print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path, total_frames) frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) video_data = decord_vr.get_batch(frame_id_list).asnumpy() video_data = torch.from_numpy(video_data) video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) return video_data def preprocess(video_data: torch.Tensor, height: int = 128, width: int = 128) -> torch.Tensor: transform = Compose( [ ToTensorVideo(), CenterCropResizeVideo((height, width)), Lambda(lambda x: 2. * x - 1.) ] ) video_outputs = transform(video_data) video_outputs = torch.unsqueeze(video_outputs, 0) return video_outputs def main(args: argparse.Namespace): device = args.device kwarg = {} # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device) # vae = CausalVAEModelWrapper(args.ae_path, **kwarg).to(device) vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device) if args.enable_tiling: vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor # vae.vae.tile_sample_min_size = 512 # vae.vae.tile_latent_min_size = 64 # vae.vae.tile_sample_min_size_t = 29 # vae.vae.tile_latent_min_size_t = 8 # if args.save_memory: # vae.vae.tile_sample_min_size = 256 # vae.vae.tile_latent_min_size = 32 # vae.vae.tile_sample_min_size_t = 9 # vae.vae.tile_latent_min_size_t = 3 dtype = torch.bfloat16 vae.eval() vae = vae.to(device, dtype=dtype) with torch.no_grad(): x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height, args.width) print("input shape", x_vae.shape) x_vae = x_vae.to(device, dtype=dtype) # b c t 
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them and hand the
    # resulting namespace to main().
    parser = argparse.ArgumentParser()

    # Options that take a value: (flag, value type, default).
    valued_options = (
        ('--video_path', str, ''),
        ('--rec_path', str, ''),
        ('--ae', str, ''),
        ('--ae_path', str, ''),
        ('--model_path', str, 'results/pretrained'),
        ('--fps', int, 30),
        ('--height', int, 336),
        ('--width', int, 336),
        ('--num_frames', int, 100),
        ('--sample_rate', int, 1),
        ('--device', str, "cuda"),
        ('--tile_overlap_factor', float, 0.25),
        ('--tile_sample_min_size', int, 512),
        ('--tile_sample_min_size_t', int, 33),
        ('--tile_sample_min_size_dec', int, 256),
        ('--tile_sample_min_size_dec_t', int, 33),
    )
    for option, value_type, default_value in valued_options:
        parser.add_argument(option, type=value_type, default=default_value)

    # Boolean switches (False unless given on the command line).
    for switch in ('--enable_tiling', '--save_memory'):
        parser.add_argument(switch, action='store_true')

    args = parser.parse_args()
    main(args)
A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field A movie trailer featuring the adventures of the 30 year old spaceman wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors.  Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth.
Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image. A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures. This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird's striking appearance. Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. A young man at his 20s is sitting on a piece of cloud in the sky, reading a book. A petri dish with a bamboo forest growing within it that has tiny red pandas running around. The camera rotates around a large stack of vintage televisions all showing different programs - 1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery. 3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest. Historical footage of California during the gold rush. A close up view of a glass sphere that has a zen garden within it.
There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand. Extreme close up of a 24 year old woman's eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic. A cartoon kangaroo disco dances. A beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera. A cat waking up its sleeping owner demanding breakfast. The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer. Borneo wildlife on the Kinabatangan River A Chinese Lunar New Year celebration video with Chinese Dragon. The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains with a clear blue sky above with wispy clouds. Reflections in the window of a train traveling through the Tokyo suburbs.
A drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast Italy, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect. A flock of paper airplanes flutters through a dense jungle, weaving around trees as if they were migrating birds. A beautiful silhouette animation shows a wolf howling at the moon, feeling lonely, until it finds its pack. New York City submerged like Atlantis. Fish, whales, sea turtles and sharks swim through the streets of New York. A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. Tour of an art gallery with many beautiful works of art in different styles. Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls.
Gorgeous sakura petals are flying through the wind along with snowflakes. A stop motion animation of a flower growing out of the windowsill of a suburban house. The story of a robot's life in a cyberpunk setting. An extreme close-up of a gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret and glasses and has a very professorial appearance, and at the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film. Basketball through hoop then explodes Archeologists discover a generic plastic chair in the desert, excavating and dusting it with great care A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood Step-printing scene of a person running, cinematic film shot in 35mm Five gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing. Tiltshift of a construction site filled with workers, equipment, and heavy machinery.
A giant, towering cloud in the shape of a man looms over the earth. The cloud man shoots lightning bolts down to the earth. A Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glisten off of their fur. The Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort William. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop for the train journey. The sky is blue and the sun is shining, making for a beautiful day to explore this majestic spot. The camera directly faces colorful buildings in Burano Italy. An adorable dalmatian looks through a window on a building on the ground floor. Many people are walking and cycling along the canal streets in front of the buildings. An adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands, 3D digital render art style. This close-up shot of a chameleon showcases its striking color changing capabilities. The background is blurred, drawing attention to the animal's striking appearance. A corgi vlogging itself in tropical Maui. A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. The scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat's orange fur.
def _all_to_all(
    input_: torch.Tensor,
    scatter_dim: int,
    gather_dim: int,
):
    """List-based all-to-all over the ``hccl_info`` group.

    ``input_`` is split into ``hccl_info.world_size`` pieces along
    ``scatter_dim``, the pieces are exchanged with ``dist.all_to_all`` and the
    received pieces are concatenated along ``gather_dim``.

    Args:
        input_: Tensor to exchange.
        scatter_dim: Dimension along which ``input_`` is split before sending.
        gather_dim: Dimension along which received pieces are concatenated.

    Returns:
        A contiguous tensor built from the received pieces.
    """
    world = hccl_info.world_size

    # Every piece is made contiguous before it is handed to the collective.
    outgoing = []
    for piece in torch.tensor_split(input_, world, scatter_dim):
        outgoing.append(piece.contiguous())

    # Receive buffers are all shaped like the first outgoing piece.
    incoming = [torch.empty_like(outgoing[0]) for _ in range(world)]
    dist.all_to_all(incoming, outgoing, group=hccl_info.group)

    merged = torch.cat(incoming, dim=gather_dim)
    return merged.contiguous()
def prepare_parallel_data(
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        encoder_attention_mask,
        pooled_projections=None,
):
    """Redistribute a batch across the sequence-parallel (``hccl_info``) group.

    ``encoder_hidden_states`` is first reshaped from ``b 1 (n x) h`` to
    ``b n x h`` with ``n = sp_size``; both masks (and ``pooled_projections``
    when given) are repeated ``sp_size`` times along dim 1. Every tensor is
    then exchanged with ``_single_all_to_all`` (gather on dim 0; scatter on
    dim 2 for ``hidden_states`` and dim 1 for the others).

    Args:
        hidden_states: Tensor whose dim 2 (frames) must be divisible by the
            sequence-parallel size.
        encoder_hidden_states: Tensor matching ``b 1 (n x) h``.
        attention_mask: 4-D mask (repeated with ``repeat(1, sp_size, 1, 1)``).
        encoder_attention_mask: 3-D mask (repeated with ``repeat(1, sp_size, 1)``).
        pooled_projections: Optional 3-D tensor; ``None`` is passed through.

    Returns:
        Tuple ``(hidden_states, encoder_hidden_states, attention_mask,
        encoder_attention_mask, pooled_projections)`` after the exchange;
        the last entry is ``None`` when no pooled projections were given.

    Raises:
        AssertionError: If the frame count is not a multiple of ``sp_size``.
    """

    def all_to_all(
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
            pooled_projections,
    ):
        # hidden_states (b c t h w) -gather0-> (sp*b c t h w) -scatter2-> (sp*b c t//sp h w)
        # encoder_hidden_states (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d)
        # attention_mask (b t*sp h w) -gather0-> (sp*b t*sp h w) -scatter1-> (sp*b t h w)
        # encoder_attention_mask (b sp l) -gather0-> (sp*b sp l) -scatter1-> (sp*b 1 l)
        # pooled_projections (b sp d) -gather0-> (sp*b sp d) -scatter1-> (sp*b 1 d)
        hidden_states = _single_all_to_all(hidden_states, scatter_dim=2, gather_dim=0, enable_HCCL=True)
        encoder_hidden_states = _single_all_to_all(encoder_hidden_states, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        attention_mask = _single_all_to_all(attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        encoder_attention_mask = _single_all_to_all(encoder_attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        if pooled_projections is not None:
            pooled_projections = _single_all_to_all(pooled_projections, scatter_dim=1, gather_dim=0, enable_HCCL=True)
        return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections

    sp_size = hccl_info.world_size
    frame = hidden_states.shape[2]
    assert frame % sp_size == 0, "frame should be a multiple of sp_size"

    encoder_hidden_states = rearrange(
        encoder_hidden_states, 'b 1 (n x) h -> b n x h',
        n=sp_size, x=encoder_hidden_states.shape[2]//sp_size
    ).contiguous()

    # pooled_projections is optional (defaults to None): only replicate it
    # when present, otherwise `.repeat` would raise AttributeError on None.
    if pooled_projections is not None:
        pooled_projections = pooled_projections.repeat(1, sp_size, 1)

    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections = all_to_all(
        hidden_states,
        encoder_hidden_states,
        attention_mask.repeat(1, sp_size, 1, 1),
        encoder_attention_mask.repeat(1, sp_size, 1),
        pooled_projections,
    )

    return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections
def initialize_sequence_parallel_group(sequence_parallel_size):
    """Create the sequence-parallel process groups and record this rank's group.

    Reads ``RANK`` (default ``'0'``) and ``WORLD_SIZE`` (default ``'1'``) from
    the environment, builds one ``dist.new_group`` per block of
    ``sequence_parallel_size`` consecutive ranks and stores the group containing
    this rank in ``hccl_info``. When ``enable_LCCL`` is set, ``lccl_info`` is
    filled in as well via ``lcal_initialize``.

    Args:
        sequence_parallel_size: Number of ranks per sequence-parallel group.

    Raises:
        AssertionError: If the world size is not divisible by
            ``sequence_parallel_size``, or if LCCL is enabled and the size is not 8.
    """
    rank = int(os.environ.get('RANK', '0'))
    world_size = int(os.environ.get("WORLD_SIZE", '1'))
    assert world_size % sequence_parallel_size == 0, "world_size must be divisible by sequence_parallel_size"

    # hccl
    hccl_info.world_size = sequence_parallel_size
    hccl_info.rank = rank

    group_count: int = world_size // sequence_parallel_size
    for group_idx in range(group_count):
        first = group_idx * sequence_parallel_size
        members = range(first, first + sequence_parallel_size)
        # new_group is called for every block on every rank; only the block
        # containing this rank is kept.
        handle = dist.new_group(members)
        if rank in members:
            hccl_info.group = handle

    if not enable_LCCL:
        return

    assert sequence_parallel_size == 8, "sequence_parallel_size should be 8 when enable_LCCL is True"
    # From here on `rank` is the position inside the sequence-parallel group.
    rank %= sequence_parallel_size
    lccl_info.world_size = sequence_parallel_size
    lccl_info.group = lcal_initialize(rank, sequence_parallel_size)
    lccl_info.rank = rank
def contigous_flatten(tensors):
    """Flatten ``tensors`` with ``_flatten_dense_tensors`` after calling
    ``.contiguous()`` on each element.

    Args:
        tensors: Iterable of tensors to flatten.

    Returns:
        The result of ``_flatten_dense_tensors`` on the contiguous tensors.
    """
    contiguous_items = []
    for item in tensors:
        contiguous_items.append(item.contiguous())
    return _flatten_dense_tensors(contiguous_items)
int(allgather_bucket_size) self.dp_process_group = dp_process_group self.dp_rank = dist.get_rank(group=self.dp_process_group) self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))] # Use torch (un)flatten ops self.flatten = contigous_flatten self.unflatten = _unflatten_dense_tensors #align nccl all-gather send buffers to 4-bye boundary self.nccl_start_alignment_factor = 16 # Build BF16/FP32 groups self.bf16_groups = [] self.bf16_groups_flat = [] self.bf16_partitioned_groups = [] self.fp32_groups_flat_partition = [] # Maintain different fp32 gradients views for convenience self.fp32_groups_gradients = [] self.fp32_groups_gradient_dict = {} self.fp32_groups_gradients_flat = [] self.fp32_groups_actual_gradients_flat = [] self.fp32_groups_gradient_flat_partition = [] self.fp32_groups_has_gradients = [] self.group_paddings = [] self.graph_harvesting = graph_harvesting if self.using_real_optimizer: self._setup_for_real_optimizer() see_memory_usage('end bf16_optimizer', force=True) def _setup_for_real_optimizer(self): dp_world_size = dist.get_world_size(group=self.dp_process_group) self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))] for i, param_group in enumerate(self.optimizer.param_groups): see_memory_usage(f'before initializing group {i}', force=True) partition_id = dist.get_rank(group=self.real_dp_process_group[i]) # grab the original list trainable_parameters = [param for param in param_group['params'] if param.requires_grad] self.bf16_groups.append(trainable_parameters) # create flat bf16 params self.bf16_groups_flat.append( self._flatten_dense_tensors_aligned(self.bf16_groups[i], self.nccl_start_alignment_factor * dp_world_size)) # Make bf16 params point to flat tensor storage self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i], flat_tensor=self.bf16_groups_flat[i]) # divide flat weights into equal sized partitions partition_size = self.bf16_groups_flat[i].numel() // 
dp_world_size bf16_dp_partitions = [ self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size) for dp_index in range(dp_world_size) ] self.bf16_partitioned_groups.append(bf16_dp_partitions) # create fp32 params partition self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach()) self.fp32_groups_flat_partition[i].requires_grad = True num_elem_list = [t.numel() for t in self.bf16_groups[i]] # create fp32 gradients self.fp32_groups_gradients_flat.append( torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)) # track individual fp32 gradients for entire model fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i], num_elem_list=num_elem_list) self.fp32_groups_gradients.append(fp32_gradients) self.fp32_groups_gradient_dict[i] = fp32_gradients # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding) length_without_padding = sum(num_elem_list) self.fp32_groups_actual_gradients_flat.append( torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding)) # flat tensor corresponding to gradient partition self.fp32_groups_gradient_flat_partition.append( torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size)) # track fp32 gradient updates self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i])) # Record padding required for alignment if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: padding = self.bf16_groups_flat[i].numel() - length_without_padding else: padding = 0 self.group_paddings.append(padding) # update optimizer param groups to reference fp32 params partition param_group['params'] = [self.fp32_groups_flat_partition[i]] see_memory_usage(f'after initializing group {i}', force=True) see_memory_usage('before initialize_optimizer', force=True) self.initialize_optimizer_states() see_memory_usage('end initialize_optimizer', 
force=True) # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() def _enable_universal_checkpoint(self): for lp_param_group in self.bf16_groups: enable_universal_checkpoint(param_list=lp_param_group) def _create_param_mapping(self): param_mapping = [] for i, _ in enumerate(self.optimizer.param_groups): param_mapping_per_group = OrderedDict() for lp in self.bf16_groups[i]: if lp._hp_mapping is not None: lp_name = self.param_names[lp] param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address() param_mapping.append(param_mapping_per_group) return param_mapping def _link_all_hp_params(self): dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, _ in enumerate(self.optimizer.param_groups): # Link bf16 and fp32 params in partition partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_size = self.bf16_groups_flat[i].numel() // dp_world_size flat_hp_partition = self.fp32_groups_flat_partition[i] link_hp_params(lp_param_list=self.bf16_groups[i], flat_hp_partition=flat_hp_partition, gradient_dict=self.fp32_groups_gradient_dict, offload_gradient_dict=None, use_offload=False, param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. This helps prevent memory fragmentation by allocating optimizer state at the beginning of training instead of after activations have been allocated. """ for param_partition, grad_partition in zip(self.fp32_groups_flat_partition, self.fp32_groups_gradient_flat_partition): # In case of grad acc dtype different than FP32, need to cast to high precision. 
param_partition.grad = grad_partition.to( param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition self.optimizer.step() if self.grad_acc_dtype is not torch.float32: for param_partition in self.fp32_groups_flat_partition: param_partition.grad = None self.clear_hp_grads() def _split_flat_tensor(self, flat_tensor, num_elem_list): assert sum(num_elem_list) <= flat_tensor.numel() tensor_list = [] offset = 0 for num_elem in num_elem_list: dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem) tensor_list.append(dense_tensor) offset += num_elem return tensor_list def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor): updated_params = self.unflatten(flat_tensor, tensor_list) for p, q in zip(tensor_list, updated_params): p.data = q.data def _flatten_dense_tensors_aligned(self, tensor_list, alignment): return self.flatten(align_dense_tensors(tensor_list, alignment)) @torch.no_grad() def step(self, closure=None): if closure is not None: raise NotImplementedError(f'{self.__class__} does not support closure.') all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(), mpu=self.mpu, norm_type=self.norm_type, use_graph=self.graph_harvesting) self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. if self.clip_grad > 0.: clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True), max_norm=self.clip_grad, global_norm=all_groups_norm, mpu=self.mpu, use_graph=self.graph_harvesting) self.optimizer.step() self.update_lp_params() self.clear_hp_grads() def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): """Perform a backward pass and copy the low-precision gradients to the high-precision copy. We copy/accumulate to the high-precision grads now to prevent accumulating in the bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1) The low-precision grads are deallocated during this procedure. 
@torch.no_grad()
def update_hp_grads(self, clear_lp_grads=False):
    """Accumulate low-precision gradients into the high-precision buffers.

    For every parameter in ``self.bf16_groups`` that has a ``.grad``, the
    gradient is cast to the dtype of the matching entry of
    ``self.fp32_groups_gradients``, added to it in place, and the entry is
    flagged in ``self.fp32_groups_has_gradients``. When
    ``self.graph_harvesting`` is set the accumulation runs through
    ``graph_process``; the flags are then set again outside of it.

    Args:
        clear_lp_grads: If True, zero each low-precision gradient in place
            after it has been accumulated.
    """

    def _update_hp_grads_func(clear_lp_grads=False):
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue

                hp_grad = self.fp32_groups_gradients[i][j]
                assert hp_grad is not None, \
                    f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'

                hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
                lp._hp_grad = hp_grad
                self.fp32_groups_has_gradients[i][j] = True

                # clear gradients
                if clear_lp_grads:
                    # Fixed: tensors have no `_zero()` method (AttributeError);
                    # `zero_()` clears in place, matching clear_lp_grads().
                    lp.grad.zero_()

    if self.graph_harvesting:
        graph_process(False, _update_hp_grads_func, clear_lp_grads)
    else:
        _update_hp_grads_func(clear_lp_grads)

    # cpu op: repeat the bookkeeping outside the (possibly graph-processed) function
    for i, group in enumerate(self.bf16_groups):
        for j, lp in enumerate(group):
            if lp.grad is None:
                continue
            self.fp32_groups_has_gradients[i][j] = True
# Restore the base optimizer's fp32 weights from the bfloat16 weights
def _restore_from_bit16_weights(self):
    """Overwrite every fp32 flat partition with this rank's bf16 partition.

    Group ``i`` of ``self.bf16_partitioned_groups`` is paired with entry ``i``
    of ``self.fp32_groups_flat_partition``. The slice copied is the one indexed
    by this process's rank inside ``self.real_dp_process_group[i]``, so each
    group uses its own partition id and is copied exactly once.
    """
    for i, (bf16_partitions, fp32_partition) in enumerate(
            zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])
        fp32_partition.data.copy_(bf16_partitions[partition_id].data)
dp_rank = dist.get_rank(group=self.dp_process_group) current_rank_sd = state_dict_list[dp_rank] ckpt_version = current_rank_sd.get(DS_VERSION, False) assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" ckpt_version = pkg_version.parse(ckpt_version) self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) if load_optimizer_states: self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) if load_from_fp32_weights: for current, saved in zip(self.fp32_groups_flat_partition, current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]): src_tensor = _get_padded_tensor(saved, current.numel()) current.data.copy_(src_tensor.data) if load_optimizer_states: self._link_all_hp_params() def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): self._load_hp_checkpoint_state(checkpoint_folder) @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups def _load_hp_checkpoint_state(self, checkpoint_dir): checkpoint_dir = os.path.join(checkpoint_dir, "zero") tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) tp_world_size = self.mpu.get_slice_parallel_world_size() for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bf16_groups[i]: if lp._hp_mapping is not None: #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, tp_world_size) def _get_padded_tensor(src_tensor, size): if src_tensor.numel() >= size: return src_tensor padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device) slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel()) slice_tensor.data.copy_(src_tensor.data) return padded_tensor ================================================ FILE: opensora/adaptor/engine.py ================================================ # Copyright (c) Microsoft Corporation. 
# SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team import os import re import stat import torch import hashlib from collections import defaultdict, OrderedDict, deque from shutil import copyfile import gc from torch.nn.modules import Module from torch.nn.parameter import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from typing import Callable, Dict, Union, Iterable import deepspeed from deepspeed import comm as dist from deepspeed.runtime.utils import see_memory_usage, DummyOptim from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ MUSGD_OPTIMIZER, LION_OPTIMIZER from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \ DATA_PARALLEL_GROUP, GLOBAL_RANK from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.compression import compression_scheduler from deepspeed.compression.constants import \ 
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \ WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \ WEIGHT_QUANTIZE_ENABLED, \ WEIGHT_QUANTIZE_GROUPS, \ WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \ WEIGHT_QUANTIZE_CHANGE_RATIO, \ WEIGHT_QUANTIZE_TYPE, \ WEIGHT_QUANTIZE_ROUNDING, \ WEIGHT_QUANTIZE_VERBOSE, \ WEIGHT_QUANTIZE_KERNEL from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS from deepspeed.runtime.sparse_tensor import SparseTensor from deepspeed.runtime import lr_schedules from deepspeed.utils import groups from deepspeed.utils import logger, log_dist, instrument_w_nvtx from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \ FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \ STEP_MICRO_TIMER, \ FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ STEP_GLOBAL_TIMER from deepspeed.utils.debug import debug_extract_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop from deepspeed.runtime.utils import clip_grad_norm_ from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \ RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \ RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \ RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop 
class EngineTimers(object):
    r"""Lists of timer identifiers used by DeepSpeedEngine, grouped by phase.

    ``forward_timers``, ``backward_timers``, ``backward_inner_timers``,
    ``backward_reduce_timers`` and ``step_timers`` each collect the enabled
    micro and/or global ``*_TIMER`` constant for that phase. ``micro_timers``
    and ``global_timers`` hold all five constants of their kind, or stay
    empty when that kind is disabled.
    """

    def __init__(self, enable_micro_timers, enable_global_timers):
        self.forward_timers = []
        self.backward_timers = []
        self.backward_inner_timers = []
        self.backward_reduce_timers = []
        self.step_timers = []
        self.global_timers = []
        self.micro_timers = []

        def _enable(timer_ids, summary):
            # timer_ids is ordered: forward, backward, backward-inner,
            # backward-reduce, step.
            fwd, bwd, bwd_inner, bwd_reduce, step = timer_ids
            self.forward_timers.append(fwd)
            self.backward_timers.append(bwd)
            self.backward_inner_timers.append(bwd_inner)
            self.backward_reduce_timers.append(bwd_reduce)
            self.step_timers.append(step)
            summary.extend(timer_ids)

        if enable_micro_timers:
            _enable(
                (FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER,
                 BACKWARD_REDUCE_MICRO_TIMER, STEP_MICRO_TIMER),
                self.micro_timers,
            )

        if enable_global_timers:
            _enable(
                (FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER,
                 BACKWARD_REDUCE_GLOBAL_TIMER, STEP_GLOBAL_TIMER),
                self.global_timers,
            )
self.gas_boundary_ctr = 0 self.dist_backend = get_accelerator().communication_backend_name() self.has_moe_layers = False self.num_experts = [] self.gate_modules = [] self.moe_layers = [] self._step_applied = False self._global_grad_norm = None self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. self.checkpoint_engine = None self._is_gradient_accumulation_boundary = None self.scale_wrt_gas = None self.losses = 0.0 # for debug purposes - can then debug print: debug_get_module_name(module) debug_extract_module_and_param_names(model) self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) if mpu is not None: if self.elasticity_enabled(): if not self.is_elastic_model_parallel_supported(): assert not self.elasticity_enabled(), ("Elasticity is not currently supported" " with model parallelism.") self._set_distributed_vars(args) dist.configure(self._config) self.monitor = MonitorMaster(self._config.monitor_config) see_memory_usage( f"DeepSpeed Engine: Before configure distributed model", force=self.memory_breakdown(), ) self.pipeline_parallelism = isinstance(model, PipelineModule) # Configure distributed model self._configure_distributed_model(model) # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict self.param_names = {param: name for name, param in model.named_parameters()} self._get_model_parameters() see_memory_usage(f"DeepSpeed Engine: After configure distributed model") # Configure wall clock timers self.timers = SynchronizedWallClockTimer() # Throughput timer self.tput_timer = ThroughputTimer( batch_size=self.train_batch_size(), steps_per_output=self.steps_per_print(), monitor_memory=False, ) log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0]) if self.flops_profiler_enabled(): self.flops_profiler = FlopsProfiler(self.module, self, 
self.flops_profiler_recompute_fwd_factor()) if training_data: self.training_dataloader = self.deepspeed_io(training_data) else: self.training_dataloader = None # Configure optimizer and scheduler self.optimizer = None self.basic_optimizer = None self.lr_scheduler = None has_optimizer = False if optimizer or self.optimizer_name(): has_optimizer = True # If no parameters given by init default to module parameters if model_parameters is None: model_parameters = self.module.parameters() # Convert model parameters from generator to list if not isinstance(model_parameters, list): model_parameters = list(model_parameters) if has_optimizer: self._configure_optimizer(optimizer, model_parameters) self._configure_lr_scheduler(lr_scheduler) self._report_progress(0) elif self.zero_optimization(): # no optim selected but zero is enabled self.optimizer = self._configure_zero_optimizer(optimizer=None) elif self.bfloat16_enabled(): self.optimizer = self._configure_bf16_optimizer(optimizer=None) # Hook optimizer for snip_momentum pruning if hasattr(model, 'pruners'): from deepspeed.compression.helper import rewrite_optimizer_step self.optimizer.pruners = model.pruners rewrite_optimizer_step(self.optimizer) # Bookkeeping for sparse support self.sparse_tensor_module_names = set() # if self.sparse_gradients_enabled(): for name, module in self.module.named_modules(): if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled(): self.sparse_tensor_module_names.add(name + ".weight") logger.info("Will convert {} to sparse tensor during training".format(name)) self.save_non_zero_checkpoint = False self.save_zero_checkpoint = False if not isinstance(self.optimizer, DeepSpeedZeRoOffload): self._configure_checkpointing(dist_init_required) if self.eigenvalue_enabled(): self.eigenvalue = self._configure_eigenvalue() if self.pld_enabled(): self.progressive_layer_drop = self._configure_progressive_layer_drop() if self.curriculum_enabled_legacy(): 
def get_batch_info(self):
    """Get all training batch related settings.

    Returns:
        train_batch_size (int): The effective training batch size. This is the amount of data
            samples that leads to one step of model update.
        train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one
            step (without gradient accumulation).
        gradient_accumulation_steps (int): Number of training steps to accumulate gradients
            before averaging and applying them.
    """
    # The three settings are exposed as accessor methods on the engine; call
    # them so the caller receives the values rather than the bound methods.
    return (
        self.train_batch_size(),
        self.train_micro_batch_size_per_gpu(),
        self.gradient_accumulation_steps(),
    )
def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
    """Forward a custom curriculum-learning schedule to the training data sampler.

    Does nothing when the engine has no training dataloader or when
    ``curriculum_learning_enabled()`` is false; the latter is only consulted
    once a dataloader is known to exist.

    Args:
        schedule_func_dict: passed unchanged to the data sampler's
            ``set_custom_curriculum_learning_schedule``.
    """
    loader = self.training_dataloader
    if loader is None:
        return
    if not self.curriculum_learning_enabled():
        return
    self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)
def is_elastic_model_parallel_supported(self):
    """Tell whether elasticity can be combined with the configured model parallelism.

    Returns:
        bool: True only when elasticity is enabled and
        ``num_gpus_per_node`` is an exact multiple of
        ``elastic_model_parallel_size``; False in every other case,
        including when elasticity is disabled.
    """
    if not self.elasticity_enabled():
        # Return an explicit bool instead of falling through to None.
        return False
    # Add code for finding number of GPUs per node automatically
    return self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0
def random_ltd_initialize(self):
    """Attach the random-LTD configuration to the model's RandomLayerTokenDrop layers.

    Walks ``self.module.named_modules()`` and, for each ``RandomLayerTokenDrop``
    layer whose module name contains the next pending id from the sorted
    ``RANDOM_LTD_LAYER_ID`` list, calls ``layer.init_config`` with the config,
    ``self.random_ltd_scheduler`` and a running index.

    Raises:
        AssertionError: if random-LTD is not enabled, or if the layer-token LR
            schedule is enabled while a client LR scheduler was supplied.
        ValueError: if the number of initialised layers differs from
            ``RANDOM_LTD_LAYER_NUM``, or if the layer-token LR schedule is
            enabled (not supported yet).
    """
    assert self.random_ltd_enabled()
    random_ltd_config = self.random_ltd_config()
    random_ltd_queue = deque(sorted(random_ltd_config[RANDOM_LTD_LAYER_ID]))
    count = 0
    for name, layer in self.module.named_modules():
        if isinstance(layer, RandomLayerTokenDrop):
            # Ids are matched as substrings of the module name, in sorted order.
            if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name:  ###[1,2,3]
                layer.init_config(random_ltd_config, self.random_ltd_scheduler, count)
                random_ltd_queue.popleft()
                count += 1

    if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count:
        # Built from adjacent literals so the message carries no source indentation.
        raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be '
                         f'equivalent to the len of random_ltd_layer_id {count}')

    if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
        assert self.client_lr_scheduler is None
        raise ValueError('not yet support')
        #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)
def autotuning_metric_path(self):
    """Return the file path where the autotuning metric is written.

    Returns:
        The configured ``autotuning_config.metric_path`` when it is truthy,
        otherwise ``autotuning_metric.json`` inside the current working
        directory.
    """
    configured = self._config.autotuning_config.metric_path
    if configured:
        return configured
    return os.path.join(os.getcwd(), "autotuning_metric.json")
def zero_use_cpu_optimizer(self):
    """Tell whether the ZeRO optimizer offload targets CPU or NVMe.

    Returns:
        False when no ``offload_optimizer`` section is configured; otherwise
        whether its ``device`` is ``OffloadDeviceEnum.cpu`` or
        ``OffloadDeviceEnum.nvme``.
    """
    offload_cfg = self._config.zero_config.offload_optimizer
    if offload_cfg is None:
        return False
    return offload_cfg.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]
def is_first_weights_partition_group(self):
    """Tell whether this rank belongs to the first weights-partition group.

    Returns:
        bool: with a negative ``mics_shard_size()``, whether weights are
        partitioned (``zero_optimization_partition_weights()``); with a
        positive shard size, whether ``self.global_rank`` is below that
        shard size; False when the shard size is zero.
    """
    shard_size = self.mics_shard_size()
    if shard_size < 0:
        return True if self.zero_optimization_partition_weights() else False
    if shard_size > 0:
        return True if self.global_rank < shard_size else False
    return False
@property
def communication_data_type(self):
    """Dtype used for communication.

    Returns the explicitly configured ``communication_data_type`` when it is
    not None. Otherwise falls back to ``torch.float16`` if fp16 is enabled,
    else ``torch.bfloat16`` if bfloat16 is enabled, else ``torch.float32``.
    """
    chosen = self._config.communication_data_type
    if chosen is None:
        if self.fp16_enabled():
            chosen = torch.float16
        elif self.bfloat16_enabled():
            chosen = torch.bfloat16
        else:
            chosen = torch.float32
    return chosen


@communication_data_type.setter
def communication_data_type(self, value):
    # Stored on the config object, so the getter reads it back on the next access.
    self._config.communication_data_type = value
self._config.zero_config.zero_quantized_nontrainable_weights

    def zero_quantized_gradients(self):
        return self._config.zero_config.zero_quantized_gradients

    def dump_state(self):
        return self._config.dump_state

    def gradient_clipping(self):
        return self._config.gradient_clipping

    def dynamic_loss_scale(self):
        """Return True when the configured ``loss_scale`` equals 0."""
        return self._config.loss_scale == 0

    def initial_dynamic_scale(self):
        return self._config.initial_dynamic_scale

    def dynamic_loss_scale_args(self):
        return self._config.dynamic_loss_scale_args

    def swap_tensor_config(self):
        return self._config.swap_tensor_config

    def aio_config(self):
        return self._config.aio_config

    def get_data_types(self):
        """Return the ``(model_dtype, grad_accum_dtype)`` pair implied by the config.

        ``model_dtype`` is float16 when fp16 is enabled, else bfloat16 when
        bfloat16 is enabled, else float32. When ``grad_accum_dtype`` is not
        configured, it defaults to float32 for a bfloat16 model without ZeRO and
        to ``model_dtype`` otherwise; when configured, it is resolved through
        ``DtypeEnum``.
        """
        model_dtype = torch.float32
        if self.fp16_enabled():
            model_dtype = torch.float16
        elif self.bfloat16_enabled():
            model_dtype = torch.bfloat16

        if self._config.grad_accum_dtype is None:
            if model_dtype == torch.bfloat16 and not self.zero_optimization():
                grad_accum_dtype = torch.float32
            else:
                grad_accum_dtype = model_dtype
        else:
            grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value

        return (model_dtype, grad_accum_dtype)

    def _optimizer_has_ckpt_event_prologue(self):
        """Return True if an optimizer is set and exposes ``checkpoint_event_prologue``."""
        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue')

    def _optimizer_has_ckpt_event_epilogue(self):
        """Return True if an optimizer is set and exposes ``checkpoint_event_epilogue``."""
        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue')

    def _configure_lr_scheduler(self, client_lr_scheduler):
        """Set ``self.lr_scheduler`` from the config or from the client argument.

        A scheduler built from the configuration takes precedence. Otherwise a
        callable ``client_lr_scheduler`` is invoked with ``self.basic_optimizer``
        and any other value is stored as-is.
        """
        # First check for scheduler in json configuration
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        if lr_scheduler:
            log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
            self.lr_scheduler = lr_scheduler
        else:
            if isinstance(client_lr_scheduler, Callable):
                log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
                self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
            else:
                log_dist('DeepSpeed using client LR scheduler', ranks=[0])
                self.lr_scheduler = client_lr_scheduler

        log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])

    def _configure_checkpointing(self, dist_init_required):
        """Choose the checkpoint engine and decide which ranks write checkpoints.

        Sets ``self.checkpoint_engine``, ``self.save_non_zero_checkpoint`` and,
        when ZeRO or bfloat16 is enabled, ``self.save_zero_checkpoint``.
        ``dist_init_required`` is not read by this method.
        """
        self.checkpoint_engine = TorchCheckpointEngine()

        if self._config is not None and self._config.nebula_config.enabled:
            try:
                from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
                    NebulaCheckpointEngine
                self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
            except ImportError as err:
                # Deliberate best-effort: keep the torch engine when Nebula is unavailable.
                logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
                self.checkpoint_engine = TorchCheckpointEngine()

        dp_rank = groups._get_sequence_data_parallel_rank()

        rank = self.local_rank if self.use_node_local_storage() else dp_rank

        # only the first data parallel process needs to store the model checkpoint
        # if you want to use node local storage this must be done by rank 0 on each
        # node
        self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()
                                                        and self.is_first_weights_partition_group())

        if self.zero_optimization() or self.bfloat16_enabled():
            # NOTE(review): relies on the optimizer exposing ``zp_process_group`` -- confirm
            # against the optimizer classes used by this adaptor.
            param_rank = dist.get_rank(group=self.optimizer.zp_process_group)

            # Only the first parameter parallel process needs to store the
            # optimizer state checkpoints for zero
            self.save_zero_checkpoint = param_rank == dp_rank

    def _scheduler_from_config(self, optimizer):
        """Instantiate the scheduler named in the config, or return None if no name is set.

        The name is looked up in ``lr_schedules`` first and in
        ``torch.optim.lr_scheduler`` second; an unknown name fails an assertion.
        """
        scheduler_name = self.scheduler_name()
        if scheduler_name is not None:
            if hasattr(lr_schedules, scheduler_name):
                scheduler = getattr(lr_schedules, scheduler_name)
            else:
                assert hasattr(torch.optim.lr_scheduler,
                               scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}"

                scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)

            scheduler_params = self.scheduler_params()
            instantiated_scheduler = scheduler(optimizer, **scheduler_params)
            return instantiated_scheduler
        else:
            return None

    def _set_distributed_vars(self, args): device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank if device_rank >= 0:
get_accelerator().set_device(device_rank) self.device = torch.device(get_accelerator().device_name(), device_rank) self.world_size = dist.get_world_size() self.global_rank = dist.get_rank() else: self.world_size = 1 self.global_rank = 0 self.device = torch.device(get_accelerator().device_name()) # Configure based on command line arguments def _configure_with_arguments(self, args, mpu): # After the distributed backend is initialized we are guaranteed the LOCAL_RANK # environment variable is set. We must align args.local_rank to this value for # backwards compatibility with scripts relying on [args|self].local_rank containing # the correct local rank info. _do_args_sanity_check will ensure this is the case. if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank) assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \ "not sure how to proceed as we're seeing conflicting local rank info." os.environ['LOCAL_RANK'] = local_rank self.local_rank = int(os.environ['LOCAL_RANK']) if hasattr(args, 'local_rank'): args.local_rank = self.local_rank # Validate command line arguments def _do_args_sanity_check(self, args): assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \ "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." 
if hasattr(args, 'local_rank') and args.local_rank is not None: assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" if args.local_rank >= 0: env_local_rank = int(os.environ.get("LOCAL_RANK")) assert ( env_local_rank == args.local_rank ), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." def _is_supported_optimizer(self, optimizer_name): return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None) def _supported_optims(self): FairseqOptimizer = None try: from fairseq.optim.fairseq_optimizer import FairseqOptimizer except ImportError: pass expected_optim_types = [Optimizer] if FairseqOptimizer: # fairseq optims are not torch.optim objects expected_optim_types.append(FairseqOptimizer) return expected_optim_types # Validate configuration based on command line arguments def _do_sanity_check(self): expected_optim_types = self._supported_optims() expected_optim_types += [type(None), Callable] assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ f'Client Optimizer is of unexpected type {type(self.client_optimizer)}' if not self.client_optimizer: if self.optimizer_name() is not None: assert self._is_supported_optimizer( self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name()) if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER): assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format( self.optimizer_name()) # Detect invalid combinations of client optimizer and client scheduler if isinstance(self.client_lr_scheduler, _LRScheduler): assert isinstance(self.client_optimizer, Optimizer), \ f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated' def _broadcast_model(self): def is_replicated(p): if 
hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: return False return True for p in self.module.parameters(): # Broadcast the model for different parameters if is_moe_param(p): if torch.is_tensor(p) and is_replicated(p): dist.broadcast(p, groups._get_expert_broadcast_src_rank(p.group_name), group=self.expert_data_parallel_group[p.group_name]) else: if torch.is_tensor(p) and is_replicated(p): dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group)

    @staticmethod
    def __check_params(model: Module, dtype: torch.dtype) -> None:
        """Dtype consistency check for ``model`` parameters (currently a no-op).

        NOTE(review): the unconditional ``return`` below short-circuits the
        method, so the ``ValueError`` branch that follows is unreachable. This
        looks like a deliberate way of disabling the check -- confirm before
        removing either the ``return`` or the dead code.
        """
        return
        if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
            raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
                             f"not {dtype}: "
                             f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")

    def _set_client_model(self, model):
        """Register ``model`` as the engine's ``module`` by writing directly into ``__dict__``."""
        # register client model in _modules so that nn.module methods work correctly
        modules = self.__dict__.get('_modules')
        modules['module'] = model
        # register module attribute in engine but avoid getattr
        self.__dict__['module'] = model

    def _configure_distributed_model(self, model): self._set_client_model(model) is_zero_init_model = self.zero_optimization_partition_weights() and any( [hasattr(param, "ds_id") for param in self.module.parameters()]) if self.fp16_enabled(): if is_zero_init_model: self.__check_params(self.module, torch.half) self.module.half() elif self.bfloat16_enabled(): if is_zero_init_model: self.__check_params(self.module, torch.bfloat16) self.module.bfloat16() else: self.__check_params(self.module, torch.float) # zero.Init() handles device placement of model if not (self.dont_change_device or is_zero_init_model): self.module.to(self.device) # MoE related initialization for _, module in self.module.named_modules(): if isinstance(module, MoE): self.has_moe_layers = True self.num_experts.append(module.num_experts) if self.has_moe_layers: for _, module in
self.module.named_modules(): if isinstance(module, TopKGate): self.gate_modules.append(module) if self.wall_clock_breakdown(): module.wall_clock_breakdown = True if isinstance(module, MOELayer): self.moe_layers.append(module) if self.wall_clock_breakdown(): module.wall_clock_breakdown = True # Pass the mpu from here to groups. For subsequent use, just query groups if self.mpu is not None: groups.mpu = self.mpu # Set deepspeed parallelism spec. for the model including expert parallelism for _, module in self.module.named_modules(): if hasattr(module, 'set_deepspeed_parallelism'): module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_) # Query the groups module to get information about various parallel groups self.local_all_to_all_group = None if self.zero_quantized_gradients(): log_dist("Using quantized gradients", ranks=[0]) self.local_all_to_all_group = groups._get_local_all_to_all_group() self.data_parallel_group = groups._get_data_parallel_group() self.dp_world_size = groups._get_data_parallel_world_size() self.zp_world_size = zp_manager.zp_size self.seq_data_parallel_group = groups._get_sequence_data_parallel_group() self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size() self.mp_world_size = groups._get_model_parallel_world_size() self.expert_parallel_group = groups._get_expert_parallel_group_dict() self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() self.sequence_parallel_size = groups._get_sequence_parallel_world_size() if self.sequence_parallel_size > 1: self.communication_data_type = self._config.seq_parallel_communication_data_type if not (self.amp_enabled() or is_zero_init_model): self._broadcast_model() # check if parameters are duplicated in optimizer param_groups def _check_for_duplicates(self, optimizer): for name, param in self.module.named_parameters(): param_id = id(param) def ids_list(group): return [id(param) for param in group] occurrence = sum([ 
ids_list(group['params']).count(param_id) if param_id in ids_list(group['params']) else 0 for group in optimizer.param_groups ]) assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior." def _do_optimizer_sanity_check(self, basic_optimizer): model_dtype, grad_accum_dtype = self.get_data_types() zero_enabled = self.zero_optimization() amp_enabled = self.amp_enabled() # config based assertions assert ( not (amp_enabled and zero_enabled) ), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" if zero_enabled: if not is_zero_supported_optimizer(basic_optimizer): assert ( self.zero_allow_untested_optimizer() ), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' if self.global_rank == 0: logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****") if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage( ) == 1 and not self.zero_cpu_offload(): return BFLOAT16 return ZERO_OPTIMIZATION elif amp_enabled: if model_dtype != grad_accum_dtype: raise NotImplementedError( "Model data type and gradient accumulation data type must be equal to use Amp") if model_dtype == torch.bfloat16 or model_dtype == torch.float16: raise NotImplementedError("Cannot enable both amp with (legacy) fp16 or bfloat16 mode") try: logger.info("Initializing Apex amp from: {}".format(amp.__path__)) except NameError: # If apex/amp is available it will be imported above raise RuntimeError("Unable to import apex/amp, please make sure it is installed") return AMP # data type checks elif model_dtype == grad_accum_dtype: if model_dtype == torch.bfloat16: if self.pipeline_parallelism: logger.warning( "**** BF16 gradient accumulation is not safe numerically with large number of 
accumulation steps, proceed with caution *****" ) return BFLOAT16 else: raise NotImplementedError( "Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation" ) if model_dtype == torch.float16: return FP16 # else optimizer_wrapper = None elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32: return BFLOAT16 else: raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type") return None # Configure optimizer def _configure_optimizer(self, client_optimizer, model_parameters): if client_optimizer is None: basic_optimizer = self._configure_basic_optimizer(model_parameters) log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0]) else: if isinstance(client_optimizer, tuple(self._supported_optims())): basic_optimizer = client_optimizer log_dist('Using client Optimizer as basic optimizer', ranks=[0]) else: basic_optimizer = client_optimizer(model_parameters) log_dist('Using client callable to create basic optimizer', ranks=[0]) if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam): if self.zero_force_ds_cpu_optimizer(): msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.' 
raise ZeRORuntimeException(msg) basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0] log_dist("Removing param_group that has no 'params' in the basic Optimizer", ranks=[0]) self._check_for_duplicates(basic_optimizer) self.basic_optimizer = basic_optimizer log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0]) optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer) if optimizer_wrapper == ZERO_OPTIMIZATION: self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif optimizer_wrapper == AMP: amp_params = self.amp_params() log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0]) model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params) self._set_client_model(model) self._broadcast_model() # TODO: maybe need to broadcast experts differently? elif optimizer_wrapper == FP16: self.optimizer = self._configure_fp16_optimizer(basic_optimizer) elif optimizer_wrapper == BFLOAT16: self.optimizer = self._configure_bf16_optimizer(basic_optimizer) else: self.optimizer = basic_optimizer log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0]) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() def _configure_basic_optimizer(self, model_parameters): optimizer_parameters = self.optimizer_params() if optimizer_parameters is None: optimizer_parameters = {} # print(optimizer_parameters.keys()) if "max_grad_norm" in optimizer_parameters.keys(): raise ValueError( "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) adam_w_mode = 
optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode if torch_adam: if not effective_adam_w_mode: optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) else: optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) else: if self.zero_use_cpu_optimizer(): from deepspeed.ops.adam import DeepSpeedCPUAdam optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters, adamw_mode=effective_adam_w_mode) else: from deepspeed.ops.adam import FusedAdam optimizer = FusedAdam( model_parameters, **optimizer_parameters, adam_w_mode=effective_adam_w_mode, ) elif self.optimizer_name() == ADAGRAD_OPTIMIZER: if self.zero_use_cpu_optimizer(): from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters) else: optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: from deepspeed.ops.lamb import FusedLamb optimizer = FusedLamb(model_parameters, **optimizer_parameters) elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER: assert not self.zero_optimization(), "1bit-Adam is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.adam import OnebitAdam optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16") elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER: assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16') elif self.optimizer_name() == 
ONEBIT_LAMB_OPTIMIZER: assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16") elif self.optimizer_name() == LION_OPTIMIZER: if self.zero_use_cpu_optimizer(): from deepspeed.ops.lion import DeepSpeedCPULion optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters) else: from deepspeed.ops.lion import FusedLion optimizer = FusedLion(model_parameters, **optimizer_parameters) elif self.optimizer_name() == MUADAM_OPTIMIZER: try: from mup import MuAdam except ImportError: logger.error(f"Install mup to use MuAdam optimizer") optimizer = MuAdam(model_parameters, **optimizer_parameters) elif self.optimizer_name() == MUADAMW_OPTIMIZER: try: from mup import MuAdamW except ImportError: logger.error(f"Install mup to use MuAdamW optimizer") optimizer = MuAdamW(model_parameters, **optimizer_parameters) elif self.optimizer_name() == MUSGD_OPTIMIZER: try: from mup import MuSGD except ImportError: logger.error(f"Install mup to use MuSGD optimizer") optimizer = MuSGD(model_parameters, **optimizer_parameters) else: torch_optimizer = getattr(torch.optim, self.optimizer_name()) optimizer = torch_optimizer(model_parameters, **optimizer_parameters) return optimizer def _configure_compression_scheduler(self): return compression_scheduler(self.module, self._config.compression_config) def _configure_random_ltd_scheduler(self, configs): return RandomLTDScheduler(configs) def _configure_quantization(self): ( quantize_weight_in_forward, quantize_enabled, q_groups, q_mixed_fp16, q_change_ratio, q_type, q_rounding, q_verbose, use_quantizer_kernel, ) = self.quantize_training() if quantize_enabled and not quantize_weight_in_forward: assert self.fp16_enabled( ), "MoQ (quantize in optimization step) weight quantization is only 
supported for FP16" quantizer = None if quantize_enabled and not quantize_weight_in_forward: from deepspeed.runtime.quantize import Quantizer quantizer = Quantizer( q_groups, q_mixed_fp16, q_change_ratio, q_type, q_rounding, q_verbose, self.eigenvalue_enabled(), use_quantizer_kernel, self.eigenvalue_layer_num() if self.eigenvalue_enabled() else 0, ) return quantizer

    def _configure_fp16_optimizer(self, optimizer):
        """Wrap ``optimizer`` in an fp16 optimizer and return the wrapper.

        ``FP16_Optimizer`` is used when ``optimizer`` is a fused Adam instance
        or the configured optimizer is 1-bit Adam / 0/1 Adam; every other
        optimizer is wrapped in ``FP16_UnfusedOptimizer``.
        """
        initial_dynamic_scale = self.initial_dynamic_scale()
        dynamic_loss_args = self.dynamic_loss_scale_args()
        clip_grad = self.gradient_clipping()
        # Apex's FusedAdam is only accepted when apex is installed.
        if APEX_INSTALLED:
            fused_opts = (apex.optimizers.FusedAdam, FusedAdam)
        else:
            fused_opts = FusedAdam
        if isinstance(optimizer, fused_opts) \
                or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
            if self.dynamic_loss_scale():
                log_dist(f'Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
                optimizer = FP16_Optimizer(
                    optimizer,
                    deepspeed=self,
                    dynamic_loss_scale=True,
                    initial_dynamic_scale=initial_dynamic_scale,
                    dynamic_loss_args=dynamic_loss_args,
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=self.optimizer_legacy_fusion(),
                    timers=timers,
                    has_moe_layers=self.has_moe_layers,
                )
            else:
                # NOTE(review): unlike the dynamic branch, no ``timers`` argument is passed here.
                log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
                optimizer = FP16_Optimizer(
                    optimizer,
                    deepspeed=self,
                    static_loss_scale=self.loss_scale(),
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=self.optimizer_legacy_fusion(),
                    has_moe_layers=self.has_moe_layers,
                )
        else:
            # NOTE(review): this message says "dynamic loss scale" regardless of the
            # value of self.dynamic_loss_scale() passed below.
            log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
            optimizer = FP16_UnfusedOptimizer(
                optimizer,
                deepspeed=self,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=dynamic_loss_args,
                mpu=self.mpu,
                clip_grad=clip_grad,
                fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER,
            )

        return optimizer

    def _configure_bf16_optimizer(self, optimizer): clip_grad =
self.gradient_clipping() if optimizer is None: optimizer = DummyOptim(list(self.module.parameters())) log_dist('Creating BF16 optimizer', ranks=[0]) timers = self.timers if self.wall_clock_breakdown() else NoopTimer() optimizer = BF16_Optimizer(optimizer, self.param_names, mpu=self.mpu, clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, timers=timers, grad_acc_dtype=self.get_data_types()[1], graph_harvesting=self.graph_harvesting()) return optimizer

    def _configure_zero_optimizer(self, optimizer):
        """Wrap ``optimizer`` in the ZeRO optimizer matching the configured stage.

        * stage <= ``ZeroStageEnum.gradients``: ``DeepSpeedZeroOptimizer``
          (a real optimizer is required).
        * stage == ``ZeroStageEnum.weights``: ``DeepSpeedZeRoOffload`` when no
          optimizer was supplied, the MiCS optimizer when ``mics_shard_size > 0``,
          otherwise ``DeepSpeedZeroOptimizer_Stage3``.

        Raises:
            Exception: if the legacy stage-1 flag is set.
            AssertionError: missing optimizer for stage <= gradients, or MoE
                layers combined with the weights stage.
            NotImplementedError: for any other stage value.
        """
        zero_stage = self.zero_optimization_stage()

        mics_shard_size = self.mics_shard_size()
        model_dtype, gradient_accumulation_dtype = self.get_data_types()

        timers = self.timers if self.wall_clock_breakdown() else NoopTimer()

        # A placeholder optimizer stands in when the caller supplied none.
        if optimizer is None:
            optimizer = DummyOptim(list(self.module.parameters()))

        if self.zero_legacy_stage1():
            raise Exception(
                "The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO."
            )

        if zero_stage <= ZeroStageEnum.gradients:
            overlap_comm = self.zero_overlap_comm()
            contiguous_gradients = self.zero_contiguous_gradients()
            round_robin_gradients = self.zero_round_robin_gradients()
            assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)

            log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
            # Overlap and contiguous grads are meaningless in stage 1 and are ignored
            if zero_stage == ZeroStageEnum.optimizer_states:
                overlap_comm = False
                round_robin_gradients = False
                # Non-MoE requires contiguous grads to be disabled w. stage 1
                if not self.has_moe_layers:
                    contiguous_gradients = False

            if isinstance(self.module, PipelineModule):
                if overlap_comm:
                    logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.")
                    overlap_comm = False
            optimizer = DeepSpeedZeroOptimizer(
                optimizer,
                self.param_names,
                timers=timers,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                contiguous_gradients=contiguous_gradients,
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(),
                allgather_bucket_size=self.zero_allgather_bucket_size(),
                dp_process_group=self.seq_data_parallel_group,
                expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
                expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=overlap_comm,
                offload_optimizer_config=self.zero_offload_optimizer(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
                gradient_predivide_factor=self.gradient_predivide_factor(),
                gradient_accumulation_steps=self.gradient_accumulation_steps(),
                ignore_unused_parameters=self.zero_ignore_unused_parameters(),
                partition_grads=zero_stage == ZeroStageEnum.gradients,
                round_robin_gradients=round_robin_gradients,
                has_moe_layers=self.has_moe_layers,
                fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
                gradient_accumulation_dtype=gradient_accumulation_dtype,
                communication_data_type=self.communication_data_type,
                elastic_checkpoint=self.zero_elastic_checkpoint())

        elif zero_stage == ZeroStageEnum.weights:
            assert not self.has_moe_layers, "MoE not supported with Stage 3"
            if isinstance(optimizer, DummyOptim):
                # No client/config optimizer was given: only set up ZeRO offload.
                log_dist("Creating ZeRO Offload", ranks=[0])
                zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
                if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None:
                    self._set_zero_group_parallelism()
                    zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
                optimizer = DeepSpeedZeRoOffload(
                    self.module,
                    timers=timers,
                    ds_config=self.config,
                    overlap_comm=self.zero_overlap_comm(),
                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),
                    max_reuse_distance=self.zero_max_reuse_distance(),
                    max_live_parameters=self.zero_max_live_parameters(),
                    param_persistence_threshold=self.zero_param_persistence_threshold(),
                    model_persistence_threshold=self.zero_model_persistence_threshold(),
                    offload_param_config=self.zero_offload_param(),
                    mpu=self.mpu,
                    zero_param_parallel_group=zero_param_parallel_group,
                    zero_quantized_weights=self.zero_quantized_weights(),
                    zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                )
            else:
                # NOTE(review): this message hard-codes "fp16" while the one below uses model_dtype.
                log_dist(
                    f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
                    f' MiCS is enabled {mics_shard_size>0},'
                    f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
                    ranks=[0])
                if mics_shard_size > 0:
                    return self._return_mics_optimizer(optimizer, timers)

                log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
                from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
                optimizer = DeepSpeedZeroOptimizer_Stage3(
                    self.module,
                    optimizer,
                    timers=timers,
                    ds_config=self.config,
                    static_loss_scale=self.loss_scale(),
                    dynamic_loss_scale=self.dynamic_loss_scale(),
                    dynamic_loss_args=self.dynamic_loss_scale_args(),
                    clip_grad=self.gradient_clipping(),
                    contiguous_gradients=self.zero_contiguous_gradients(),
                    reduce_bucket_size=self.zero_reduce_bucket_size(),
                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),
                    max_reuse_distance=self.zero_max_reuse_distance(),
                    max_live_parameters=self.zero_max_live_parameters(),
                    param_persistence_threshold=self.zero_param_persistence_threshold(),
                    model_persistence_threshold=self.zero_model_persistence_threshold(),
                    dp_process_group=self.seq_data_parallel_group,
                    all2all_process_group=self.local_all_to_all_group,
                    reduce_scatter=self.zero_reduce_scatter(),
                    overlap_comm=self.zero_overlap_comm(),
                    offload_optimizer_config=self.zero_offload_optimizer(),
                    offload_param_config=self.zero_offload_param(),
                    sub_group_size=self.zero_sub_group_size(),
                    offload_ratio=self.zero_partial_offload(),
                    mpu=self.mpu,
                    postscale_gradients=self.postscale_gradients(),
                    gradient_predivide_factor=self.gradient_predivide_factor(),
                    gradient_accumulation_steps=self.gradient_accumulation_steps(),
                    aio_config=self.aio_config(),
                    gradient_accumulation_dtype=gradient_accumulation_dtype,
                    communication_data_type=self.communication_data_type,
                    zero_hpz_partition_size=self.zero_hpz_partition_size(),
                    zero_quantized_weights=self.zero_quantized_weights(),
                    zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                )

        else:
            raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

        return optimizer

    def _return_mics_optimizer(self, basic_optimizer, timers): from deepspeed.runtime.zero.mics import MiCS_Optimizer model_dtype, gradient_accumulation_dtype = self.get_data_types() optimizer = MiCS_Optimizer(self.module, basic_optimizer, timers=timers, ds_config=self.config, static_loss_scale=self.loss_scale(), dynamic_loss_scale=self.dynamic_loss_scale(), dynamic_loss_args=self.dynamic_loss_scale_args(), clip_grad=self.gradient_clipping(), contiguous_gradients=self.zero_contiguous_gradients(), reduce_bucket_size=self.zero_reduce_bucket_size(), prefetch_bucket_size=self.zero_prefetch_bucket_size(), max_reuse_distance=self.zero_max_reuse_distance(), max_live_parameters=self.zero_max_live_parameters(), param_persistence_threshold=self.zero_param_persistence_threshold(), model_persistence_threshold=self.zero_model_persistence_threshold(), dp_process_group=self.seq_data_parallel_group, reduce_scatter=self.zero_reduce_scatter(), overlap_comm=self.zero_overlap_comm(), offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(), sub_group_size=self.zero_sub_group_size(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), gradient_accumulation_steps=self.gradient_accumulation_steps(), aio_config=self.aio_config(), gradient_accumulation_dtype=gradient_accumulation_dtype, communication_data_type=self.communication_data_type) return optimizer def _configure_eigenvalue(self): eigenvalue = Eigenvalue( verbose=self.eigenvalue_verbose(), max_iter=self.eigenvalue_max_iter(), tol=self.eigenvalue_tol(), stability=self.eigenvalue_stability(), gas_boundary_resolution=self.eigenvalue_gas_boundary_resolution(), layer_name=self.eigenvalue_layer_name(), layer_num=self.eigenvalue_layer_num(), ) return eigenvalue def _configure_progressive_layer_drop(self): pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma()) return pld def _configure_curriculum_scheduler_legacy(self): scheduler = CurriculumScheduler(self.curriculum_params_legacy()) return scheduler @staticmethod def is_map_style_dataset(obj): return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") @staticmethod def is_iterable_style_dataset(obj): return isinstance(obj, torch.utils.data.IterableDataset) # hasattr(obj, "__iter__") should work as well def dataloader_drop_last(self): return self._config.dataloader_drop_last def was_step_applied(self) -> bool: """Returns True if the latest ``step()`` produced in parameter updates. Note that a ``False`` return is not an error condition. Steps are frequently no-ops, such as between gradient accumulation boundaries or when overflows occur. Returns: bool: Whether the latest ``step()`` modified model parameters. 
def deepspeed_io(self,
                 dataset,
                 batch_size=None,
                 route=ROUTE_TRAIN,
                 pin_memory=True,
                 data_sampler=None,
                 collate_fn=None,
                 num_local_io_workers=None):
    """Wrap ``dataset`` in a ``DeepSpeedDataLoader`` configured from this engine.

    Arguments:
        dataset: map-style or iterable-style dataset (see ``is_map_style_dataset`` /
            ``is_iterable_style_dataset``).
        batch_size: defaults to ``self.train_micro_batch_size_per_gpu()`` when None.
        route: ``ROUTE_TRAIN`` attaches ``self.tput_timer``; ``ROUTE_PREDICT`` and
            ``ROUTE_EVAL`` get a non-shuffling ``DistributedSampler`` when no
            sampler is supplied.
        pin_memory: forwarded to the loader.
        data_sampler: optional sampler forwarded to the loader.
        collate_fn: defaults to ``self.collate_fn`` when None.
        num_local_io_workers: forwarded to the loader.

    Raises:
        ValueError: if ``dataset`` is neither map-style nor iterable-style.
    """
    supported = self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)
    if not supported:
        raise ValueError("Training data must be a torch Dataset")

    if batch_size is None:
        batch_size = self.train_micro_batch_size_per_gpu()
    if collate_fn is None:
        collate_fn = self.collate_fn

    # Only the training route is timed.
    io_timer = self.tput_timer if route == ROUTE_TRAIN else None

    # Sharding info comes from the engine unless an mpu is present.
    world_size, rank = self.dp_world_size, self.global_rank
    if self.mpu is not None:
        world_size = self.mpu.get_data_parallel_world_size()
        rank = self.mpu.get_data_parallel_rank()

    needs_default_sampler = data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL)
    if needs_default_sampler:
        data_sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
        )

    loader_config = {}
    if self.curriculum_learning_enabled():
        loader_config = {
            CURRICULUM_LEARNING: self.curriculum_learning_enabled(),
            DATA_EFFICIENCY: self.data_efficiency_config(),
            DATA_PARALLEL_GROUP: self.data_parallel_group,
            GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(),
            GLOBAL_RANK: self.global_rank,
            DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
        }

    return DeepSpeedDataLoader(dataset=dataset,
                               batch_size=batch_size,
                               pin_memory=pin_memory,
                               collate_fn=collate_fn,
                               local_rank=self.local_rank,
                               tput_timer=io_timer,
                               num_local_io_workers=num_local_io_workers,
                               data_sampler=data_sampler,
                               data_parallel_world_size=world_size,
                               data_parallel_rank=rank,
                               dataloader_drop_last=self.dataloader_drop_last(),
                               deepspeed_dataloader_config=loader_config)
mode=True): r"""""" self.warn_unscaled_loss = True self.module.train(mode) def eval(self): r"""""" self.warn_unscaled_loss = True self.module.train(False) def _scale_loss_by_gas(self, prescaled_loss): if isinstance(prescaled_loss, torch.Tensor): scaled_loss = prescaled_loss / self.gradient_accumulation_steps() elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list): scaled_loss = [] for l in prescaled_loss: if isinstance(l, torch.Tensor): scaled_loss.append(l / self.gradient_accumulation_steps()) else: scaled_loss.append(l) else: scaled_loss = prescaled_loss if self.warn_unscaled_loss: logger.warning(f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}") self.warn_unscaled_loss = False return scaled_loss @instrument_w_nvtx def forward(self, *inputs, **kwargs): r"""Execute forward propagation Arguments: *inputs: Variable length input list **kwargs: variable length keyword arguments """ if self.autotuning_profile_model_info(): ma = get_ma_status() else: see_memory_usage("Engine before forward", force=self.memory_breakdown()) flops_profiler_active = (self.flops_profiler_enabled() and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) # used to check quantization happens at step 0! 
if self.global_steps == 0 and hasattr(self, "compression_scheduler"): self.compression_scheduler.step(step_zero_check=True) if self.quantizer: tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( ) == 2 else self.optimizer.fp16_groups if self.compression_scheduler.weight_quantization_enabled: self.quantizer.quantize( tensor_to_quantize, (self.optimizer.overflow if self.fp16_enabled() else False), self.eigenvalue_enabled(), None, ) if flops_profiler_active: self.flops_profiler.start_profile(ignore_list=None) if self.module.training: if self.progressive_layer_drop: kwargs.update(self.progressive_layer_drop.get_state()) if self.__class__.__name__ != "PipelineEngine": # TODO: The above if condition is a HACK since for PipelineEngine # it's difficult to inject argument in forward pass. if self.module.training and self.curriculum_enabled_legacy(): self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) if self.module.training and self.random_ltd_enabled(): self.random_ltd_scheduler.update_seq(self.global_steps) if self.zero_optimization_partition_weights(): # Enable automated discovery of external parameters by indicating that # we are in a forward pass. 
for module in self.module.modules(): module._parameters._in_forward = True self._start_timers(self.engine_timers.forward_timers) if self.training_dataloader is None: self.tput_timer.start() if self.fp16_auto_cast(): inputs = self._cast_inputs_half(inputs) # print(f"RANK[{self.global_rank}] self.fp16_auto_cast() is {self.fp16_auto_cast()}") loss = self.module(*inputs, **kwargs) # print(f"RANK[{self.global_rank}]'s loss is {loss}") if self.zero_optimization_partition_weights(): # Disable automated discovery of external parameters for module in self.module.modules(): module._parameters._in_forward = False self._stop_timers(self.engine_timers.forward_timers) if flops_profiler_active: self.flops_profiler.stop_profile() if self.autotuning_profile_model_info(): activation_mem = get_ma_status() - ma self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path()) exit() else: see_memory_usage("Engine after forward", force=self.memory_breakdown()) return loss def _cast_inputs_half(self, inputs): if isinstance(inputs, (list, tuple)): new_inputs = [] for v in inputs: new_inputs.append(self._cast_inputs_half(v)) return inputs.__class__(new_inputs) elif isinstance(inputs, dict): new_inputs = {} for k, v in inputs.items(): new_inputs[k] = self._cast_inputs_half(v) return new_inputs elif hasattr(inputs, 'half'): return inputs.half() else: return inputs def print_forward_breakdown(self, fwd_time): gate_time = 0.0 moe_time = 0.0 falltoall = 0.0 salltoall = 0.0 for gate in self.gate_modules: #logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms") gate_time += gate.gate_time for l in self.moe_layers: #logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}") moe_time += l.time_moe falltoall += l.time_falltoall salltoall += l.time_salltoall # TODO: Allreduce/average them across ranks for more accurate 
@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
    r"""Execute the backward pass on ``loss`` through the active precision / ZeRO path.

    Arguments:
        loss: Torch tensor on which to execute backward propagation
        allreduce_gradients: deprecated; a falsy value logs a warning and also
            skips the ``self.allreduce_gradients()`` call made after backward
        release_loss: currently has no effect
        retain_graph: bool, default: false
            forwarded to the underlying backward call, except on the bf16 path
            (not forwarded) and the eigenvalue paths (forced to True)
        scale_wrt_gas: divide the loss by the gradient accumulation steps when
            there is more than one; replaced by ``self.scale_wrt_gas`` whenever
            that attribute is not None

    Returns:
        The loss, scaled if scaling was applied.
    """
    see_memory_usage("Engine before backward", force=self.memory_breakdown())

    if self.scale_wrt_gas is not None:
        scale_wrt_gas = self.scale_wrt_gas

    if not allreduce_gradients:
        logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")

    # Scale the loss w.r.t. gradient accumulation if needed.
    if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
        loss = self._scale_loss_by_gas(loss.float())

    # Accumulate the training loss and, on rank 0 at a boundary, publish it.
    self.losses += loss.mean().item()
    if self.monitor.enabled and self.is_gradient_accumulation_boundary() and self.global_rank == 0:
        self.summary_events = [(
            f"Train/Samples/train_loss",
            self.losses,
            self.global_samples,
        )]
        self.monitor.write_events(self.summary_events)

    self._start_timers(self.engine_timers.backward_timers)

    assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
        "must provide optimizer during init in order to use backward"

    self._start_timers(self.engine_timers.backward_inner_timers)

    if self.zero_optimization():
        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
        self.optimizer.backward(loss, retain_graph=retain_graph)
    elif self.amp_enabled():
        # AMP requires delaying unscale when inside gradient accumulation boundaries
        # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
        postpone_unscale = not self.is_gradient_accumulation_boundary()
        with amp.scale_loss(loss, self.optimizer, delay_unscale=postpone_unscale) as scaled_loss:
            scaled_loss.backward(retain_graph=retain_graph)
    elif self.fp16_enabled():
        if self.eigenvalue_enabled():
            graph_kwargs = dict(create_graph=True, retain_graph=True)
        else:
            graph_kwargs = dict(retain_graph=retain_graph)
        self.optimizer.backward(loss, **graph_kwargs)
    elif self.bfloat16_enabled():
        self.optimizer.backward(loss)
    else:
        if self.eigenvalue_enabled():
            graph_kwargs = dict(create_graph=True, retain_graph=True)
        else:
            graph_kwargs = dict(retain_graph=retain_graph)
        loss.backward(**graph_kwargs)

    self._stop_timers(self.engine_timers.backward_inner_timers)

    self._start_timers(self.engine_timers.backward_reduce_timers)
    if allreduce_gradients and self.enable_backward_allreduce:
        # Traditional code path that allreduces the module parameter grads
        self.allreduce_gradients()
    self._stop_timers(self.engine_timers.backward_reduce_timers)

    self._stop_timers(self.engine_timers.backward_timers)

    if release_loss:
        # loss.data = None
        pass

    see_memory_usage("Engine after backward", force=self.memory_breakdown())
    return loss
""" for param_name, param in self.module.named_parameters(): param.grad = None def clip_fp32_gradients(self): clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu) def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if self.gradient_clipping() > 0.0: if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()): self.clip_fp32_gradients() elif self.amp_enabled(): # AMP's recommended way of doing clipping # https://nvidia.github.io/apex/advanced.html#gradient-clipping master_params = amp.master_params(self.optimizer) clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu) self.optimizer.step() if hasattr(self.optimizer, '_global_grad_norm'): self._global_grad_norm = self.optimizer._global_grad_norm # Quantize the updated parameter if there is no overflow if self.quantizer: tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( ) == 2 else self.optimizer.fp16_groups if self.compression_scheduler.weight_quantization_enabled: self.quantizer.quantize( tensor_to_quantize, (self.optimizer.overflow if self.fp16_enabled() else False), self.eigenvalue_enabled(), block_eigenvalue, ) # zero grad in basic optimizer could be unreliable and may not exhibit # the behavior that we want if self.bfloat16_enabled(): # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"): self.optimizer.zero_grad() else: pass elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): self.optimizer.zero_grad() else: self.zero_grad() report_progress = self.global_rank == 0 if self.global_rank else True # Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function. 
def step(self, lr_kwargs=None):
    r"""Execute the weight update step after forward and backward propagation
    on effective_train_batch.

    The model is only updated (via ``_take_model_step``) when
    ``is_gradient_accumulation_boundary()`` is true; otherwise this call only
    stops timers, emits logging and increments ``self.micro_steps``.

    Arguments:
        lr_kwargs: optional keyword arguments forwarded to ``_take_model_step``.
    """
    see_memory_usage("Engine before step", force=self.memory_breakdown())

    # Check early because self.global_steps is incremented at some point here.
    # TODO: Delay self.global_steps increment until very end of this function.
    flops_profiler_active = self.flops_profiler_enabled(
    ) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0

    self._start_timers(self.engine_timers.step_timers)

    assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
        "must provide optimizer during init in order to use step"

    report_progress = False

    self._step_applied = False  # assume False, will flip to True

    # Update the model when we reach gradient accumulation boundaries
    if self.is_gradient_accumulation_boundary():
        self.gas_boundary_ctr += 1

        if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
                and self.quantizer.any_precision_switch()):
            log_dist(f"computing eigenvalue...", ranks=[0])
            self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
                                                                       self.optimizer.cur_scale)

        if self.progressive_layer_drop:
            self.progressive_layer_drop.update_state(self.global_steps)

        if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()
                and self.quantizer.any_precision_switch()):
            self._take_model_step(lr_kwargs, self.block_eigenvalue)
        else:
            self._take_model_step(lr_kwargs)

        report_progress = self.global_rank == 0 if self.global_rank else True

    self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress)

    self._stop_timers(self.engine_timers.step_timers)

    # Log learning rate
    if self.monitor.enabled:
        if self.is_gradient_accumulation_boundary():
            if self.global_rank == 0:
                self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)]

                if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
                    self.summary_events.append((
                        f"Train/Samples/loss_scale",
                        self.optimizer.cur_scale,
                        self.global_samples,
                    ))

                if (self.eigenvalue_enabled()
                        and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()):
                    # Iterate the values view directly: it cannot be indexed, and
                    # only the local name (not ``self.ev_values``) is assigned here.
                    # NOTE(review): element [0] is presumably the eigenvalue itself;
                    # confirm against ``compute_eigenvalue``.
                    for i, ev_value in enumerate(self.block_eigenvalue.values()):
                        self.summary_events.append((
                            f"Train/Eigenvalues/ModelBlockParam_{i}",
                            ev_value[0],
                            self.global_samples,
                        ))
                self.monitor.write_events(self.summary_events)

    # Check flops profiling
    if flops_profiler_active:
        if self.autotuning_enabled():
            self.flops = self.flops_profiler.get_total_flops() * 3
            self.fwd_duration = self.flops_profiler.get_total_duration()
        else:
            self.flops_profiler.print_model_profile(
                profile_step=self.global_steps,
                module_depth=self.flops_profiler_module_depth(),
                top_modules=self.flops_profiler_top_modules(),
                detailed=self.flops_profiler_detailed(),
                output_file=self.flops_profiler_output_file(),
            )
        self.flops_profiler.end_profile()

    if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1):
        self._autotuning_exit()

    if self.wall_clock_breakdown():
        # Log micro timing and reset
        self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown())

    if self.wall_clock_breakdown() or self.flops_profiler_enabled():
        # Log global timing and reset
        if self.is_gradient_accumulation_boundary():
            if self.monitor.enabled:
                self._write_monitor()

            if self.has_moe_layers:
                fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False)
                self.print_forward_breakdown(fwd_time=fwd_time)

            self.timers.log(self.engine_timers.global_timers)

    self.micro_steps += 1
    see_memory_usage("Engine after step", force=self.memory_breakdown())
msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0 titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0 titer *= self.gradient_accumulation_steps() msg["latency"] = titer msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer msg["throughput"] = self.train_batch_size() * 1_000_000 / \ msg["latency"] print_json_dist(msg, [0], path=self.autotuning_metric_path()) log_dist( f"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}", ranks=[0]) import atexit atexit.register(print, "Autotuning: done with running current ds config.") exit() def _write_monitor(self): if self.global_rank == 0: self.summary_events = [ ( f"Train/Samples/elapsed_time_ms_forward", self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( f"Train/Samples/elapsed_time_ms_backward", self.timers(BACKWARD_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( f"Train/Samples/elapsed_time_ms_backward_inner", self.timers(BACKWARD_INNER_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( f"Train/Samples/elapsed_time_ms_backward_allreduce", self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ( f"Train/Samples/elapsed_time_ms_step", self.timers(STEP_GLOBAL_TIMER).elapsed(reset=False), self.global_samples, ), ] self.monitor.write_events(self.summary_events) def _get_optimizer_param(self, param_name): result = [] if not self.optimizer: return result for group in self.optimizer.param_groups: if param_name in group: result.append(group[param_name]) else: result.append(0.0) return result def get_lr(self): return self._get_optimizer_param("lr") def get_type(self): return self._get_optimizer_param("type") def get_mom(self): if self.optimizer_name() in ["SGD", "RMSprop"]: return self._get_optimizer_param("momentum") else: return self._get_optimizer_param("betas") def get_pld_theta(self): if self.progressive_layer_drop: return 
def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000):
    """All-reduce ``bucket`` in chunks via ``self.allreduce_and_copy``.

    Tensors are accumulated in order; a chunk is flushed as soon as its total
    element count exceeds ``numel_per_bucket``, and any remainder is flushed
    at the end.
    """
    pending, pending_numel = [], 0
    for grad in bucket:
        pending.append(grad)
        pending_numel += grad.numel()
        if pending_numel > numel_per_bucket:
            self.allreduce_and_copy(pending, dp_group)
            pending, pending_numel = [], 0
    if pending:
        self.allreduce_and_copy(pending, dp_group)
def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
    """Reduce non-expert gradients, one dtype bucket at a time.

    Each bucket from ``split_half_float_double_sparse`` is reduced over the
    mpu data-parallel group when pipeline parallelism is on, otherwise over the
    sequence-data-parallel group. Sparse buckets use the sparse path; all
    others go through ``allreduce_no_retain`` with ``elements_per_buffer``.
    """
    for bucket_type, bucket in split_half_float_double_sparse(grads):
        dp_group = (self.mpu.get_data_parallel_group()
                    if self.pipeline_parallelism else groups._get_sequence_data_parallel_group())

        is_sparse_bucket = bucket_type == SparseTensor.type()
        if is_sparse_bucket:
            self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
        else:
            self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
def sparse_allreduce_bucket(self, bucket, dp_group):
    """Apply ``self.sparse_allreduce`` to each entry of ``bucket``, in order, and return the results as a list."""
    return [self.sparse_allreduce(sparse_grad, dp_group) for sparse_grad in bucket]
@staticmethod
def load_moe_state_dict(checkpoint_path,
                        tag,
                        state_dict,
                        old_moe_load,
                        model=None,
                        mpu=None,
                        num_experts=1,
                        checkpoint_engine=TorchCheckpointEngine()):
    """Load per-expert checkpoint files and merge them into ``state_dict`` in place.

    Arguments:
        checkpoint_path: directory passed to ``_get_expert_ckpt_name``.
        tag: checkpoint tag passed to ``_get_expert_ckpt_name``.
        state_dict: dict updated in place with the loaded expert entries.
        old_moe_load: when truthy, use the legacy layout (layer id ``-1``) and
            derive the local expert count from ``num_experts``; otherwise walk
            ``model.named_modules()`` and load the experts of every ``MoE`` module.
        model: module to walk; required when ``old_moe_load`` is falsy.
        mpu: optional model-parallel unit forwarded to ``_get_expert_ckpt_name``.
        num_experts: a single int or an iterable of ints; the maximum is used.
        checkpoint_engine: object whose ``load(path, map_location=...)`` reads a file.
    """
    moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'

    def _merge_expert(layer_id, global_expert_id, local_expert_id):
        # Load one expert file, rewrite global -> local expert ids in its keys,
        # then fold it into the caller's state_dict.
        expert_state_dict = checkpoint_engine.load(
            DeepSpeedEngine._get_expert_ckpt_name(checkpoint_path, layer_id, global_expert_id, tag, mpu),
            map_location=torch.device('cpu'))
        for key in list(expert_state_dict.keys()):
            local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', f'{moe_str_prefix}{local_expert_id}')
            expert_state_dict[local_key] = expert_state_dict.pop(key)
        state_dict.update(expert_state_dict)

    if old_moe_load:
        expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())

        # The default ``num_experts=1`` is a bare int, which ``max()`` rejects.
        max_experts = num_experts if isinstance(num_experts, int) else max(num_experts)
        num_local_experts = max_experts // groups._get_expert_parallel_world_size(
            groups._get_max_expert_size_name())

        for local_expert_id in range(num_local_experts):
            global_expert_id = expp_rank * num_local_experts + local_expert_id
            # -1 means ignore layer_id
            _merge_expert(-1, global_expert_id, local_expert_id)
    else:
        moe_layer_id = 0
        for n_module, module in model.named_modules():
            if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
                group_name = module.expert_group_name
                num_local_experts = module.num_local_experts
                expp_rank = groups._get_expert_parallel_rank(group_name)
                # loop all local_experts
                for local_expert_id in range(num_local_experts):
                    global_expert_id = expp_rank * num_local_experts + local_expert_id
                    _merge_expert(moe_layer_id, global_expert_id, local_expert_id)
                moe_layer_id += 1
_get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): if mp_placeholder is not None: mp_rank_str = mp_placeholder else: mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank_str = f"{mp_rank:02d}" if self.zero_optimization_partition_weights(): filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.zp_process_group)) ckpt_name = os.path.join( checkpoints_path, str(tag), f"{filename}_mp_rank_{mp_rank_str}_model_states.pt", ) else: ckpt_name = os.path.join( checkpoints_path, str(tag), "mp_rank_" + mp_rank_str + "_model_states.pt", ) return ckpt_name def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() ckpt_name = os.path.join(checkpoints_path, str(tag), f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt') return ckpt_name @staticmethod def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None): mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() if layer_id <= -1: # Used to support old checkpoint loading ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag), f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') else: # Used to support new checkpoint loading ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag), f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') return ckpt_name def _get_all_ckpt_names(self, checkpoints_path, tag): # It is required that (checkpoints_path, tag) are consistent among all ranks. ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*") import glob ckpt_files = glob.glob(ckpt_file_pattern) ckpt_files.sort() return ckpt_files def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False, custom_load_fn=None): """ Load training checkpoint Arguments: load_dir: Required. 
Directory to load the checkpoint from tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match. load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting. custom_load_fn: Optional. Custom model load function. Returns: A tuple of ``load_path`` and ``client_state``. *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. *``client_state``: State dictionary used for loading required training states in the client code. Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine before ``load_checkpoint()``. """ if tag is None: latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest" latest_path = os.path.join(load_dir, latest_tag) if os.path.isfile(latest_path): with open(latest_path, "r") as fd: tag = fd.read().strip() else: if self.load_universal_checkpoint(): raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist') else: logger.warning( f"Unable to find latest file at {latest_path}, if trying to load latest " "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." 
) return None, None if self._optimizer_has_ckpt_event_prologue(): # Prepare for checkpoint load by ensuring all parameters are partitioned self.optimizer.checkpoint_event_prologue() load_path, client_states = self._load_checkpoint(load_dir, tag, load_module_strict=load_module_strict, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_lr_scheduler_states, load_module_only=load_module_only, custom_load_fn=custom_load_fn) load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) if load_zero_checkpoint: success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) if not success: self.optimizer._restore_from_bit16_weights() if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() if self.load_universal_checkpoint(): self.optimizer.update_lp_params() if load_zero_checkpoint: self.update_optimizer_step(step=client_states['iteration'] + 1) return load_path, client_states def _load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False, custom_load_fn=None): from deepspeed.runtime.state_dict_factory import SDLoaderFactory ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) is_pipe_parallel = isinstance(self.module, PipelineModule) mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel) if checkpoint is None: return None, None fetch_z3_params = False if self.zero_optimization_partition_weights() and not load_optimizer_states: checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) fetch_z3_params = True if is_pipe_parallel: # Pipeline parallelism uses this to load its own checkpoint files. 
self._curr_ckpt_path = os.path.join(load_dir, tag) if self.has_moe_layers: # print(checkpoint.keys()) old_moe_load = False if not isinstance(checkpoint['num_experts'], list): old_moe_load = True DeepSpeedEngine.load_moe_state_dict(load_dir, tag, state_dict=checkpoint['module'], old_moe_load=old_moe_load, model=self.module, mpu=self.mpu, num_experts=self.num_experts, checkpoint_engine=self.checkpoint_engine) if not self.load_universal_checkpoint(): self.load_module_state_dict(checkpoint=checkpoint, strict=load_module_strict, custom_load_fn=custom_load_fn, fetch_z3_params=fetch_z3_params) self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] if 'zp_world_size' not in checkpoint: checkpoint['zp_world_size'] = self.zp_world_size self.loaded_checkpoint_zp_world_size = checkpoint['zp_world_size'] optim_checkpoint = None if load_module_only: deepspeed_states = ['module'] if self.optimizer is not None and self.fp16_enabled(): self.optimizer.refresh_fp32_params() else: has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: if self.has_moe_layers: largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu')) else: optim_checkpoint = checkpoint if self.fp16_enabled() or self.bfloat16_enabled(): self.optimizer.load_state_dict(optim_checkpoint['optimizer'], load_optimizer_states=load_optimizer_states) else: optim_checkpoint = checkpoint self.optimizer.load_state_dict(optim_checkpoint['optimizer']) if load_lr_scheduler_states and self.lr_scheduler is not None: self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint: 
self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd']) if self.training_dataloader is not None and self.curriculum_learning_enabled( ) and 'data_sampler' in checkpoint: self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler']) def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters): result = set() for name in original_set: if name in loaded_parameters and name not in loaded_set: continue # parameter existed in previous model and was not sparse result.add(name) for name in loaded_set: if name in original_parameters: result.add(name) # parameter exists in both configs and it was sparse return result if 'sparse_tensor_module_names' in checkpoint: sparse_tensor_module_names = checkpoint['sparse_tensor_module_names'] elif 'csr_tensor_module_names' in checkpoint: sparse_tensor_module_names = checkpoint['csr_tensor_module_names'] else: sparse_tensor_module_names = None if sparse_tensor_module_names is not None: if load_module_strict: self.sparse_tensor_module_names = sparse_tensor_module_names else: self.sparse_tensor_module_names = get_sparse_tensor_module_names( self.sparse_tensor_module_names, sparse_tensor_module_names, dict(self.module.named_parameters()), checkpoint["module"]) self.global_steps = checkpoint['global_steps'] self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size()) self.skipped_steps = checkpoint['skipped_steps'] self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] deepspeed_states = [ 'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'zp_world_size', 'mp_world_size', 'data_sampler', 'random_ltd', 'dp_world_size', ] client_state = {} if load_lr_scheduler_states: deepspeed_states.append('lr_scheduler') if load_optimizer_states: deepspeed_states.append('optimizer') client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states} if optim_checkpoint is not None: 
client_state['optimizer'] = optim_checkpoint['optimizer'] return load_path, client_state def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): load_serial = None # When use loading checkpoint serial, checkpoint loading start from local rank 0, # all other local rank would be paused, waiting for its rank-1 peer ready and its notification. if self._config.zero_config.pipeline_loading_checkpoint: assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Only stage3 support for pipeline checkpoint loading" load_serial = torch.zeros(1).to(self.device) if dist.get_local_rank() != 0: dist.recv(tensor=load_serial, src=dist.get_rank() - 1) if self.load_universal_checkpoint(): zero_sd_list = None checkpoint_folder = f'{os.path.join(load_dir, tag)}' else: if load_optimizer_states and self.zp_world_size != self.loaded_checkpoint_zp_world_size: raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ f"world size of {self.loaded_checkpoint_zp_world_size} but the " \ f"current world size is {self.zp_world_size}. 
def update_optimizer_step(self, step):
    """Overwrite every ``'step'`` entry tracked by the wrapped optimizer.

    Walks ``self.optimizer.param_groups`` and the per-parameter ``state`` of
    ``self.optimizer.optimizer``. Wherever a ``'step'`` key is present it is
    replaced by ``step``. If the existing value is a ``torch.Tensor`` the new
    value is a tensor with the same dtype and device; otherwise ``step`` is
    stored as given.

    Arguments:
        step: The step value to store.
    """

    def _overwrite(container):
        current = container['step']
        if not isinstance(current, torch.Tensor):
            container['step'] = step
            return
        # Keep tensor-valued counters as tensors with unchanged dtype/device.
        container['step'] = torch.tensor(step, dtype=current.dtype, device=current.device)

    wrapper = self.optimizer
    per_param_state = wrapper.optimizer.state

    for param_group in wrapper.param_groups:
        if 'step' in param_group:
            _overwrite(param_group)

        for param in param_group['params']:
            if param not in per_param_state:
                continue
            entry = per_param_state[param]
            if len(entry) > 0 and 'step' in entry:
                _overwrite(entry)
optim_states if "optim_states.pt" in ckpt_name: ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt") if os.path.exists(ckpt_name_try): zero_ckpt_names[i] = ckpt_name_try continue return zero_ckpt_names def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): zero_sd_list = [] for i, ckpt_name in enumerate(zero_ckpt_names): _state = None if ckpt_name is None: _state = {OPTIMIZER_STATE_DICT: None} # Fully load state for current rank elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.zp_process_group) == i: _state = self.checkpoint_engine.load( ckpt_name, map_location='cpu', ) else: _state = {OPTIMIZER_STATE_DICT: None} zero_sd_list.append(_state) zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list] logger.info(f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}") return zero_optimizer_sd def _get_all_zero_checkpoints(self, load_dir, tag): for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]: zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode) if zero_ckpt_names is not None: # Warn if loading checkpoint of different bit16 type if bf16_mode is not self.bfloat16_enabled(): checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) return None def _checkpoint_tag_validation(self, tag): if self.checkpoint_tag_validation_enabled(): s_hash = hashlib.sha1(tag.encode()) bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device) max_bhash = bhash.clone() min_bhash = bhash.clone() dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX) dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN) valid = all(min_bhash == bhash) and all(max_bhash == bhash) msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is 
not consistent across " "all ranks. Including rank unique information in checkpoint tag could cause issues when " "restoring with different world sizes.") if self.checkpoint_tag_validation_fail(): assert valid, msg elif not valid: logger.warning(msg) def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False): """Save training checkpoint Arguments: save_dir: Required. Directory for saving the checkpoint tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided. Tag name must be the same across all ranks. client_state: Optional. State dictionary used for saving required training states in the client code. save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint. exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state. Important: all processes must call this method and not just the process with rank 0. It is because each process needs to save its master weights and scheduler+optimizer states. This method will hang waiting to synchronize with other processes if it's called just for the process with rank 0. 
""" if self._optimizer_has_ckpt_event_prologue(): # Custom preparation for checkpoint save, if applicable self.optimizer.checkpoint_event_prologue() rank = self.local_rank if self.use_node_local_storage() else self.global_rank # This is to make sure the checkpoint names are created without collision # There seems to be issue creating them in parallel # Ensure save_dir directory exists if rank == 0: self.checkpoint_engine.makedirs(save_dir, exist_ok=True) dist.barrier() if tag is None: tag = f"global_step{self.global_steps}" # Ensure tag is a string tag = str(tag) self.checkpoint_engine.create(tag) # Ensure checkpoint tag is consistent across ranks self._checkpoint_tag_validation(tag) if self.has_moe_layers: self.save_non_zero_checkpoint = False self._create_checkpoint_file(save_dir, tag, False) self._save_moe_checkpoint(save_dir, tag, client_state=client_state, exclude_frozen_parameters=exclude_frozen_parameters) # We distribute the task of saving layer checkpoint files among # data parallel instances, so all procs should call _save_checkpoint. # All procs then call module_state_dict(), but only procs of data # parallel rank 0 save the general model params. 
if not self.has_moe_layers: self._create_checkpoint_file(save_dir, tag, False) self._save_checkpoint(save_dir, tag, client_state=client_state, exclude_frozen_parameters=exclude_frozen_parameters) if self.save_zero_checkpoint: self._create_zero_checkpoint_files(save_dir, tag) self._save_zero_checkpoint(save_dir, tag) if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() # Save latest checkpoint tag self.checkpoint_engine.commit(tag) if save_latest and rank == 0: with open(os.path.join(save_dir, 'latest'), 'w') as fd: fd.write(tag) dist.barrier() return True def _get_non_moe_state_dict(self, full_state_dict): """ Get the state dict of the non-moe layers """ for key in list(full_state_dict.keys()): if 'expert' in key and 'moe.gate.wg.weight' not in key: full_state_dict.pop(key) return full_state_dict def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): save_path = self._get_ckpt_name(save_dir, tag) # A hack to save the checkpointing directory. Pipeline parallelism overrides # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. # Using layer_#_export_# to save the model's expert state_dict moe_layer_id = 0 for n_module, module in self.module.named_modules(): if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0: group_name = module.expert_group_name num_local_experts = module.num_local_experts expp_rank = groups._get_expert_parallel_rank(group_name) exp_dp_rank = groups._get_expert_data_parallel_rank(group_name) # print(expp_rank, exp_dp_rank) if exp_dp_rank != 0: moe_layer_id += 1 continue # get all moe parameters moe_state_dict = {} for n, p in module.state_dict().items(): if 'expert' in n and 'moe.gate.wg.weight' not in n: moe_state_dict[n_module + '.' + n] = p moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' # print(moe_state_dict.keys()) # until now, everything is fine. 
So the bug happens at next few lines # Reorder the moe name rank, so that each checkpoint only has one expert experts_state_dict = defaultdict(dict) for key in list(moe_state_dict.keys()): m = re.match(f".*{moe_str_prefix}([0-9]+).*", key) local_expert_id = None if not m: logger.warn(f'No expert found in key {key}.') else: local_expert_id = m.group(1) global_expert_id = expp_rank * \ num_local_experts + int(local_expert_id) expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}', f'{moe_str_prefix}{global_expert_id}') # truncating extra tensor (shared) storage truncated = moe_state_dict.pop(key).clone().detach() experts_state_dict[str(global_expert_id)][expert_key] = truncated # let save the moe parameters for global_expert_id, expert_state_dict in experts_state_dict.items(): # save the moe parameters moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu) if self.random_ltd_enabled(): expert_state_dict = remove_random_ltd_state_dict(expert_state_dict) self.checkpoint_engine.save(expert_state_dict, moe_save_path) moe_layer_id += 1 self._curr_ckpt_path = os.path.join(save_dir, tag) largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name) # In the case of E + D parallelism, only the # first expert parallel group should save the expert weights # since each expert parallel group is a copy of the model's experts if exp_dp_rank != 0: return # Save optimizer states. They are different across each exp parallel rank. 
optimizer_state = { 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None } # TODO: why use BufferedWriter not the path file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) self.checkpoint_engine.save(optimizer_state, file_path) # get non-moe parameters model_state_dict = self._get_non_moe_state_dict( self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)) if expp_rank == 0: # TODO: update num experts info,.. in checkpoint state = { 'module': model_state_dict, 'lr_scheduler': self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, 'data_sampler': self.training_dataloader.data_sampler.state_dict() if (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, 'random_ltd': self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, 'sparse_tensor_module_names': self.sparse_tensor_module_names, 'skipped_steps': self.skipped_steps, 'global_steps': self.global_steps, 'global_samples': self.global_samples, 'zp_world_size': self.zp_world_size, 'dp_world_size': self.dp_world_size, 'mp_world_size': self.mp_world_size, 'num_experts': self.num_experts } state.update(client_state) logger.info(f'Saving model checkpoint: {save_path}') self.checkpoint_engine.save(state, save_path) def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name) try: checkpoint_name = name_function(save_dir, tag) path = os.path.dirname(checkpoint_name) self.checkpoint_engine.makedirs(path, exist_ok=True) except: logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}") return False return True def _create_zero_checkpoint_files(self, save_dir, tag): success = True # zero checkpoint files are created sequentially for rank in range(dist.get_world_size(self.optimizer.zp_process_group)): if rank == self.global_rank: success = 
self._create_checkpoint_file(save_dir, tag, True) dist.barrier(group=self.optimizer.zp_process_group) return success def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): save_path = self._get_ckpt_name(save_dir, tag) zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters # A hack to save the checkpointing directory. Pipeline parallelism overrides # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. The module_state_dict() implementation in # PipelineEngine expects the save path to be set in self._curr_ckpt_path. self._curr_ckpt_path = os.path.join(save_dir, tag) module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) self._curr_ckpt_path = None state = dict(module=module, buffer_names=self._get_buffer_names(), optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None, frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func) if save_frozen_param else None, shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None, frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func) if save_frozen_param else None, lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, data_sampler=self.training_dataloader.data_sampler.state_dict() if (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, sparse_tensor_module_names=self.sparse_tensor_module_names, skipped_steps=self.skipped_steps, global_steps=self.global_steps, global_samples=self.global_samples, 
dp_world_size=self.seq_dp_world_size, mp_world_size=self.mp_world_size, ds_config=self.config, ds_version=version) state.update(client_state) if self.save_non_zero_checkpoint: log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1]) self.checkpoint_engine.save(state, save_path) def _get_buffer_names(self): buffer_names = [] # we save buffer names so that we could extract later the real buffers from the saved # state_dict["module"] in the non-zero checkpoint - the buffers are already there but they # are intermixed with param placeholders # have to traverse the tree to be able to skip non-persistent buffers def get_layer_named_buffers(module, prefix=""): for name, buf in module.named_buffers(recurse=False): if buf is not None and name not in module._non_persistent_buffers_set: buffer_names.append(prefix + name) for name, child in module.named_children(): if child is not None: get_layer_named_buffers(child, prefix + name + ".") get_layer_named_buffers(self.module, prefix="") return buffer_names def _get_param_shape_func(self, param): return param.ds_shape if hasattr(param, 'ds_id') else param.shape def _get_param_fragment_func(self, param): return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu() def _get_zero_frozen_param_attributes(self, attr_func): frozen_param_fragments = OrderedDict() for param in self.module.parameters(): if param.requires_grad: continue if param not in self.param_names: raise ValueError(f"failed to find frozen {param} in named params") name = self.param_names[param] frozen_param_fragments[name] = attr_func(param) return frozen_param_fragments def _get_zero_param_shapes(self): """Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the optimizer. the names are exactly as in state_dict. The order is absolutely important, since the saved data is just flattened data with no identifiers and requires reconstruction in the same order it was saved. 
We can't rely on self.module.named_parameters() to get the saved tensors, as some params will be missing and others unsaved and then it'd be impossible to reconstruct state_dict from the flattened weights. optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions. """ param_group_shapes = [] cnt = 0 numel = 0 # zero2 started using a round_robin_bit16_groups which is a shuffled version of bit16_groups - # if we don't use it, we get parameters ordered incorrectly if hasattr(self.optimizer, "round_robin_bit16_groups"): bit16_groups = self.optimizer.round_robin_bit16_groups elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"): bit16_groups = self.optimizer.bf16_groups else: bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( ) == 2 else self.optimizer.fp16_groups for bit16_group in bit16_groups: param_shapes = OrderedDict() for param in bit16_group: cnt += 1 numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape if param not in self.param_names: raise ValueError(f"failed to find optimizer param in named params") name = self.param_names[param] param_shapes[name] = shape # uncomment to debug zero_to_fp32.py problems # if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})") param_group_shapes.append(param_shapes) # if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params") return param_group_shapes def _get_shared_params(self): """ Returns a dict of shared params, which can later be used to reconstruct the original state dict, e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name of the variable that isn't stored and the value is the actual param holding data. 
""" shared_index = {} shared_params_by_full_name = {} is_zero3_model = (self.zero_optimization_partition_weights() and any(hasattr(param, "ds_id") for param in self.module.parameters())) def get_layer_state_dict(module, prefix=""): # handle params for name, param in module.named_parameters(recurse=False): if param is None or (is_zero3_model and not hasattr(param, "ds_id")): continue key = prefix + name # When weights are manged by stage 3, we can't rely on param.data_ptr() as it will be reused # as weights get gathered and reduced, but param.ds_id is unique across all zero weights # (and shared params will have the same param.ds_id) param_id = param.ds_id if is_zero3_model else param.data_ptr() if param_id in shared_index: # shared weights #print(f"`{key}` is shared with `{shared_index[param_id]}`") shared_params_by_full_name[key] = shared_index[param_id] else: shared_index[param_id] = key for name, child in module.named_children(): if child is not None: get_layer_state_dict(child, prefix + name + ".") if dist.get_rank() == 0: get_layer_state_dict(self.module, prefix="") return shared_params_by_full_name def _copy_recovery_script(self, save_path): base_dir = os.path.dirname(os.path.dirname(__file__)) script = "zero_to_fp32.py" src = os.path.join(base_dir, "utils", script) dst = os.path.join(save_path, script) #logger.info(f"creating recovery script {dst}") copyfile(src, dst) self._change_recovery_script_permissions(dst) def _change_recovery_script_permissions(self, dst): # make executable (safeguard for file shares - Azure as example) try: os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) except (FileNotFoundError, PermissionError) as e: #this message is used in unit test TestZeRONonDistributed logger.info( f'Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.' 
) def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version) self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: self._copy_recovery_script(save_path) ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') def _zero3_consolidated_16bit_state_dict(self): """ Get a full non-partitioned state_dict with fp16 weights on cpu. Important: this function must be called on all ranks and not just rank 0. This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but: 1. consolidates the weights from different partitions on gpu0 2. works on one layer at a time to require as little gpu0 memory as possible, by moving the already consolidated weights to cpu 3. takes care to keep the shared params shared when gradually copying the params to cpu Returns: a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks """ if not self.zero_optimization_partition_weights(): raise ValueError("this function requires ZeRO-3 mode") state_dict = OrderedDict() if dist.get_rank() == 0 else None shared_params = {} def get_layer_state_dict(module, prefix=""): # gather one layer at a time to be memory-efficient # must use modifier_rank=0 to release GPU memory after each layer gathered #see_memory_usage("before GatheredParameters", force=True) with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): if dist.get_rank() == 0: # handle params for name, param in module.named_parameters(recurse=False): if param is None: continue key = prefix + name # can't rely on param.data_ptr() as it will be reused as weights gets # gathered and reduced, but param.ds_id is unique across all zero weights # (and shared params will have the same param.ds_id) if param.ds_id in 
def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
    """Backwards-compatible alias: this method was renamed to ``save_16bit_model``."""
    return self.save_16bit_model(save_dir, save_filename)


def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
    """
    Save 16bit model weights

    Writes the 16bit model weights to ``os.path.join(save_dir, save_filename)``.

    Arguments:
        save_dir: Required. Directory for saving the model
        save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

    Returns:
        ``True`` when a model has been saved, ``False`` otherwise. Nothing is saved when
        weights are partitioned and ``stage3_gather_16bit_weights_on_model_save`` is ``False``.

    Important: all processes must call this method and not just the process with rank 0,
    because the processes need to work in sync to gather the weights. Calling it only on
    rank 0 will hang waiting to synchronize with the other processes.
    """
    path = os.path.join(save_dir, save_filename)

    if not self.zero_optimization_partition_weights():
        state_dict = self.module.state_dict()
    elif self.zero_gather_16bit_weights_on_model_save():
        # consolidation is expensive in time and memory and therefore isn't a default
        state_dict = self._zero3_consolidated_16bit_state_dict()
    else:
        # the model will be bogus if not consolidated so don't confuse the user by saving it
        logger.info(
            f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False")
        return False

    tag = f"global_step{self.global_steps}"
    self.checkpoint_engine.create(tag)

    # only rank 0 writes the file; create/commit are issued by every caller
    is_writer = dist.get_rank() == 0
    if is_writer:
        self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
        logger.info(f"Saving model weights to {path}, tag: {tag}")
        self.checkpoint_engine.save(state_dict, path)

    self.checkpoint_engine.commit(tag)

    return True
""" if hasattr(self.optimizer, 'empty_partition_cache'): self.optimizer.empty_partition_cache() gc.collect() get_accelerator().empty_cache() ================================================ FILE: opensora/adaptor/modules.py ================================================ import torch from torch import nn from torch.nn import functional as F def fp32_layer_norm_forward(self, inputs: torch.Tensor) -> torch.Tensor: origin_dtype = inputs.dtype return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps).to(origin_dtype) def fp32_silu_forward(self, inputs: torch.Tensor) -> torch.Tensor: return torch.nn.functional.silu(inputs.float(), inplace=self.inplace).to(inputs.dtype) def fp32_gelu_forward(self, inputs: torch.Tensor) -> torch.Tensor: return torch.nn.functional.gelu(inputs.float(), approximate=self.approximate).to(inputs.dtype) def replace_with_fp32_forwards(): nn.GELU.forward = fp32_gelu_forward nn.SiLU.forward = fp32_silu_forward nn.LayerNorm.forward = fp32_layer_norm_forward ================================================ FILE: opensora/adaptor/stage_1_and_2.py ================================================ # Copyright (c) Microsoft Corporation. 
# SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team import torch import os import pdb from deepspeed import comm as dist from packaging import version as pkg_version from collections import OrderedDict from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage, inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.utils import logger from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.accelerator import get_accelerator from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) from deepspeed.utils import link_hp_params from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.utils import groups from opensora.adaptor.zp_manager import zp_manager # Toggle this to true to enable correctness test # with gradient partitioning and without pg_correctness_test = False OPTIMIZER_ALLGATHER_TIMER = 'optimizer_allgather' OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients' OPTIMIZER_STEP_TIMER = 'optimizer_step' OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER] def input(msg): return def split_half_float_double(tensors): device_type = get_accelerator().device_name() dtypes = [ "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type), 
"torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type) ] buckets = [] for i, dtype in enumerate(dtypes): bucket = [t for t in tensors if t.type() == dtype] if bucket: buckets.append(bucket) return buckets def isclose(a, b, rtol=1e-09, atol=0.0): return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) def lcm(x, y): from fractions import gcd # or can import gcd from `math` in Python 3 return x * y // gcd(x, y) def get_alignment_padding(tensor_list, alignment): num_elements = sum([tensor.numel() for tensor in tensor_list]) remainder = num_elements % alignment return (alignment - remainder) if remainder else remainder def move_to_cpu(tensor_list): for tensor in tensor_list: tensor.data = tensor.data.cpu() def print_rank_msg(msg): print(f"rank {dist.get_rank()} - {msg}") def _get_padded_tensor(src_tensor, size): if src_tensor.numel() >= size: return src_tensor padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device) slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel()) slice_tensor.data.copy_(src_tensor.data) return padded_tensor def contigous_flatten(tensors): return _flatten_dense_tensors([tensor.contiguous() for tensor in tensors]) def all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group): for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)): partition_id = dist.get_rank(group=zp_process_group[group_id]) dp_world_size = dist.get_world_size(group=zp_process_group[group_id]) if dp_world_size == 1: # no groups share optimizer states # pipeline parallel with bf16 will default call this even if dp size = 1. 
def __init__(self,
             init_optimizer,
             param_names,
             timers,
             static_loss_scale=1.0,
             dynamic_loss_scale=False,
             dynamic_loss_args=None,
             verbose=True,
             contiguous_gradients=True,
             reduce_bucket_size=500000000,
             use_multi_rank_bucket_allreduce=True,
             allgather_bucket_size=5000000000,
             dp_process_group=None,
             expert_parallel_group=None,
             expert_data_parallel_group=None,
             reduce_scatter=True,
             overlap_comm=False,
             offload_optimizer_config=None,
             mpu=None,
             clip_grad=0.0,
             gradient_accumulation_dtype=torch.float32,
             communication_data_type=torch.float16,
             postscale_gradients=True,
             gradient_predivide_factor=1.0,
             gradient_accumulation_steps=1,
             ignore_unused_parameters=True,
             partition_grads=True,
             round_robin_gradients=False,
             has_moe_layers=False,
             fp16_master_weights_and_gradients=False,
             elastic_checkpoint=False):
    """Wrap ``init_optimizer`` so that each rank owns one flat partition of every param group.

    For every param group the constructor: collects the trainable params, moves them to
    CPU, optionally round-robin reorders them, flattens them into one aligned buffer on the
    accelerator, re-points the model params at slices of that buffer, splits the buffer into
    per-rank partitions, and replaces the group's ``params`` with a detached float (or half,
    when ``fp16_master_weights_and_gradients``) clone of this rank's partition. It then sets
    up gradient-partitioning bookkeeping, backward hooks, the loss scaler and the optimizer
    states.

    Partitioning uses the process group provided by ``zp_manager`` (``self.zp_process_group``);
    ``dp_process_group`` is stored separately as ``self.dp_process_group``.

    Raises:
        SystemError: if no accelerator is available.
        AssertionError: on unsupported option combinations (see the asserts below).
    """
    if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
        self.cpu_offload = True
        self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory
    else:
        self.cpu_offload = False
        self.cpu_offload_pin_memory = False

    if dist.get_rank() == 0:
        logger.info(f"Reduce bucket size {reduce_bucket_size}")
        logger.info(f"Allgather bucket size {allgather_bucket_size}")
        logger.info(f"CPU Offload: {self.cpu_offload}")
        logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')
    # The fused optimizer does all the work. We need this layer for two reason:
    # 1. maintain same user API from apex.fp16_utils
    # 2. keep common stuff here in case we need to add new fused optimizer later

    self.elastic_checkpoint = elastic_checkpoint
    self.param_names = param_names
    self.mpu = mpu
    # differences from apex.fp16_utils:
    # - assume all model params in fp16
    # - assume all params requires grad
    # - flat by groups, not keeping state. TODO: remove state explicitly?
    # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
    if not get_accelerator().is_available():
        raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).")
    self.optimizer = init_optimizer

    # Use torch (un)flatten ops
    self.flatten = contigous_flatten
    self.unflatten = _unflatten_dense_tensors

    # ZeRO stage 1 (False) or 2 (True)
    self.partition_gradients = partition_grads
    self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"

    self.timers = timers

    self.reduce_scatter = reduce_scatter

    self.overlap_comm = overlap_comm

    self.deepspeed_adam_offload = self.cpu_offload

    self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'

    # process group used for partitioning, provided by zp_manager
    zp_manager.init_group()
    self.zp_process_group = zp_manager.zp_group
    zp_rank = dist.get_rank(group=self.zp_process_group)
    zp_size = dist.get_world_size(group=self.zp_process_group)
    print(f"zp rank is {zp_rank}, zp_size={zp_size}")
    self.dp_process_group = dp_process_group
    self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
    # expert parallel group
    self.ep_process_group = expert_parallel_group

    # data parallel group for experts
    self.expert_dp_process_group = expert_data_parallel_group

    # data parallel size for non-experts
    dp_size = dist.get_world_size(group=self.dp_process_group)

    # For MoE models this maybe different for different param group
    # It will be modified during MoE setup later in the init
    self.real_zp_process_group = [self.zp_process_group for i in range(len(self.optimizer.param_groups))]
    self.real_dp_process_group = [self.dp_process_group for i in range(len(self.optimizer.param_groups))]
    self.partition_count = [zp_manager.zp_size for i in range(len(self.optimizer.param_groups))]

    self.is_gradient_accumulation_boundary = True

    # CPU-Offload requires contiguous gradients
    self.contiguous_gradients = contiguous_gradients or self.cpu_offload

    self.has_moe_layers = has_moe_layers
    if self.has_moe_layers:
        self._configure_moe_settings()
    self._global_grad_norm = 0.

    if mpu is None:
        self.model_parallel_group = None
        self.model_parallel_world_size = 1
        self.model_parallel_rank = 0
    else:
        self.model_parallel_group = mpu.get_model_parallel_group()
        self.model_parallel_world_size = mpu.get_model_parallel_world_size()
        self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)

    self.overflow = False
    self.clip_grad = clip_grad
    self.communication_data_type = communication_data_type
    self.gradient_predivide_factor = gradient_predivide_factor
    self.postscale_gradients = postscale_gradients
    self.gradient_accumulation_steps = gradient_accumulation_steps
    self.micro_step_id = 0
    self.ignore_unused_parameters = ignore_unused_parameters
    self.round_robin_gradients = round_robin_gradients

    self.extra_large_param_to_reduce = None
    self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients

    if self.fp16_master_weights_and_gradients:
        assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], (
            f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."
            f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}."
            f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam.")

    if self.reduce_scatter:
        valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
        assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
        # These two messages interpolate the stage name, so they must be f-strings.
        assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
        assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"

    # param flattened by groups
    self.bit16_groups = []
    self.bit16_groups_flat = []

    # param partitioned by data parallel degree
    # this will contain a list of equal sized tensors
    # each of which will be updated by a different process
    self.parallel_partitioned_bit16_groups = []

    # a single 32-bit partition of the parallel partitioned parameters
    # that this process will update
    self.single_partition_of_fp32_groups = []

    # param partition info

    # These are the parameters in each group that will not be updated by this process directly
    self.params_not_in_partition = []

    # These are the parameters that will be updated by this process directly
    self.params_in_partition = []

    # Offset from the first parameter in the self.params_in_partition
    # the parameter boundaries may not align with partition boundaries
    # so we need to keep track of the offset
    self.first_offset = []

    # number of elements per partition in each group
    self.partition_size = []

    # align nccl all-gather send buffers
    # NOTE(review): the previous comment derived a factor of 2 (4-byte alignment / sizeof(fp16)),
    # but the value used here is 16 - confirm which alignment is intended.
    self.nccl_start_alignment_factor = 16
    assert (
        allgather_bucket_size % self.nccl_start_alignment_factor == 0
    ), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "

    self.all_reduce_print = False
    self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
    self.gradient_accumulation_dtype = gradient_accumulation_dtype

    if self.dtype != self.gradient_accumulation_dtype:
        self.use_separate_grad_accum = True
    else:
        self.use_separate_grad_accum = False
    if self.use_separate_grad_accum and not self.partition_gradients:
        self.use_grad_accum_attribute = True
    else:
        self.use_grad_accum_attribute = False

    self.round_robin_bit16_groups = []
    self.round_robin_bit16_indices = []

    # Use different parallel to do all_to_all_reduce related things
    # padding on each partition for alignment purposes
    self.groups_padding = []
    # loop to deal with groups
    for i, param_group in enumerate(self.optimizer.param_groups):
        partition_id = dist.get_rank(group=self.real_zp_process_group[i])

        # push this group to list before modify
        # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
        trainable_parameters = []
        for param in param_group['params']:
            if param.requires_grad:
                param.grad_accum = None
                trainable_parameters.append(param)
        self.bit16_groups.append(trainable_parameters)

        # not sure why apex was cloning the weights before flattening
        # removing cloning here

        see_memory_usage(f"Before moving param group {i} to CPU")
        # move all the parameters to cpu to free up GPU space for creating flat buffer
        move_to_cpu(self.bit16_groups[i])
        empty_cache()
        see_memory_usage(f"After moving param group {i} to CPU", force=False)

        # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
        # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
        # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
        # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
        if self.round_robin_gradients:
            round_robin_tensors, round_robin_indices = self._round_robin_reorder(
                self.bit16_groups[i], dist.get_world_size(group=self.real_zp_process_group[i]))
        else:
            round_robin_tensors = self.bit16_groups[i]
            round_robin_indices = list(range(len(self.bit16_groups[i])))

        self.round_robin_bit16_groups.append(round_robin_tensors)
        self.round_robin_bit16_indices.append(round_robin_indices)

        # create flat buffer in CPU and move to GPU
        self.bit16_groups_flat.append(
            self.flatten_dense_tensors_aligned(
                self.round_robin_bit16_groups[i],
                self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_zp_process_group[i])).to(
                    get_accelerator().current_device_name()))
        see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

        # Record padding required for alignment
        if partition_id == dist.get_world_size(group=self.real_zp_process_group[i]) - 1:
            padding = self.bit16_groups_flat[i].numel() - sum(
                [t.numel() for t in self.round_robin_bit16_groups[i]])
        else:
            padding = 0
        self.groups_padding.append(padding)

        if dist.get_rank(group=self.real_zp_process_group[i]) == 0:
            see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

        # set model bit16 weight to slices of flattened buffer
        self._update_model_bit16_weights(i)

        # divide the flat weights into near equal partition equal to the data parallel degree
        # each process will compute on a different part of the partition
        data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
        self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

        # verify that every partition's start address is a multiple of
        # 2 * nccl_start_alignment_factor bytes
        for partitioned_data in data_parallel_partitions:
            assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)

        # A partition of the fp32 master weights that will be updated by this process.
        # Note that the params in single_partition_of_fp32_groups is cloned and detached
        # from the origin params of the model.
        if not fp16_master_weights_and_gradients:
            self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
                self.device).clone().float().detach())
        else:
            self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
                self.device).clone().half().detach())

        # Set local optimizer to have flat params of its own partition.
        # After this, the local optimizer will only contain its own partition of params.
        # In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1).
        self.single_partition_of_fp32_groups[
            i].requires_grad = True  # keep this in case internal optimizer uses it
        param_group['params'] = [self.single_partition_of_fp32_groups[i]]

        # true division: partition_size is stored as a float; users convert with int() where needed
        partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_zp_process_group[i])
        params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
            self.round_robin_bit16_groups[i], partition_size, partition_id)

        self.partition_size.append(partition_size)
        self.params_in_partition.append(params_in_partition)
        self.params_not_in_partition.append(params_not_in_partition)
        self.first_offset.append(first_offset)

    self.reduce_bucket_size = int(reduce_bucket_size)
    self.use_multi_rank_bucket_allreduce = use_multi_rank_bucket_allreduce
    self.allgather_bucket_size = int(allgather_bucket_size)

    self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream()
    # self.copy_grad_stream = get_accelerator().Stream()
    self.callback_queued = False

    self.param_dict = {}

    # map between param_id and bool to specify if a param is in this partition
    self.is_param_in_current_partition = {}

    self.grads_in_ipg_bucket = []
    self.params_in_ipg_bucket = []
    self.elements_in_ipg_bucket = 0
    self.params_already_reduced = []
    self._release_ipg_buffers()
    self.previous_reduced_grads = None
    self.ipg_bucket_has_moe_params = False

    # simplified param id
    self.param_id = {}

    # interesting code: unique ids being assigned to individual parameters
    largest_param_numel = 0
    count = 0
    for i, params_group in enumerate(self.bit16_groups):
        for param in params_group:
            unique_id = id(param)
            self.param_id[unique_id] = count
            self.param_dict[count] = param
            self.params_already_reduced.append(False)
            if param.numel() > largest_param_numel:
                largest_param_numel = param.numel()
            count = count + 1

    for param_group in self.params_in_partition:
        for param in param_group:
            self.is_param_in_current_partition[self.get_param_id(param)] = True

    for param_group in self.params_not_in_partition:
        for param in param_group:
            self.is_param_in_current_partition[self.get_param_id(param)] = False

    if self.cpu_offload:
        self.accumulated_grads_in_cpu = {}
        self.norm_for_param_grads = {}
        self.local_overflow = False
        self.grad_position = {}
        self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel,
                                                            device=self.device,
                                                            dtype=self.dtype)
        if self.cpu_offload_pin_memory:
            self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
                self.temp_grad_buffer_for_cpu_offload)
        self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,
                                                            device=get_accelerator().current_device_name(),
                                                            dtype=self.dtype)
        for i, params_group in enumerate(self.bit16_groups):
            self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i])

    # mapping from parameter to partition that it belongs to
    self.param_to_partition_ids = {}

    # stores if a partition has been reduced in this step
    self.is_partition_reduced = {}

    # number of grads in partition that still need to be computed
    self.remaining_grads_in_partition = {}

    # total number of grads in partition
    self.total_grads_in_partition = {}

    # stores if a grad in a partition has been computed or not
    self.is_grad_computed = {}

    # stores the offset at which a parameter gradient needs to be inserted in a partition
    self.grad_partition_insertion_offset = {}

    # the offset in the gradient at which it must be inserted at the beginning of the partition
    self.grad_start_offset = {}

    # will store the averaged gradients required by this partition
    self.averaged_gradients = {}

    # For cpu_offload, will store the averaged gradients required by this partition
    self.offload_gradient_dict = {}

    # store index of first parameter in each partition
    self.first_param_index_in_partition = {}

    # initializes all data structures for implementing gradient partitioning
    self.initialize_gradient_partitioning_data_structures()

    # resets the data structure value for the next backward propagation
    self.reset_partition_gradient_structures()

    # creates backward hooks for gradient partitioning
    if self.partition_gradients or self.overlap_comm:
        self.create_reduce_and_remove_grad_hooks()

    self.custom_loss_scaler = False
    self.external_loss_scale = None

    # we may have a way of fusing dynamic scale. Do not support for now
    self.loss_scaler = CreateLossScaler(dtype=self.dtype,
                                        static_loss_scale=static_loss_scale,
                                        dynamic_scaling=dynamic_loss_scale,
                                        dynamic_loss_args=dynamic_loss_args)
    self.dynamic_loss_scale = self.loss_scaler.dynamic

    if self.dtype != torch.float16:
        # Only fp16 should use dynamic loss scaling
        assert self.loss_scaler.cur_scale == 1.0
        assert not self.dynamic_loss_scale

    see_memory_usage("Before initializing optimizer states", force=True)
    self.initialize_optimizer_states()
    see_memory_usage("After initializing optimizer states", force=True)

    if dist.get_rank() == 0:
        logger.info(f"optimizer state initialized")

    if dist.get_rank(group=self.zp_process_group) == 0:
        see_memory_usage(f"After initializing ZeRO optimizer", force=True)

    self._link_all_hp_params()
    self._enable_universal_checkpoint()
    self._param_slice_mappings = self._create_param_mapping()
offload_gradient_dict=self.offload_gradient_dict, use_offload=self.cpu_offload, param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_zp_process_group[i]) def is_moe_group(self, group): return 'moe' in group and group['moe'] def _configure_moe_settings(self): # if we're using ZeRO stage 2, ensure contiguous gradients are used if self.partition_gradients: assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion if not self.partition_gradients and not self.contiguous_gradients: logger.warn( "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.") assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE" assert any( [self.is_moe_group(group) for group in self.optimizer.param_groups] ), "The model has moe layers, but None of the param groups are marked as MoE. 
Create a param group with 'moe' key set to True before creating optimizer" self.is_moe_param_group = [] for i, group in enumerate(self.optimizer.param_groups): if self.is_moe_group(group): assert all([is_moe_param(param) for param in group['params']]), "All params in MoE group must be MoE params" self.real_zp_process_group[i] = self.expert_dp_process_group[group['name']] self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']]) self.is_moe_param_group.append(True) else: self.is_moe_param_group.append(False) assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE" assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE" def _update_model_bit16_weights(self, group_index): updated_params = self.unflatten(self.bit16_groups_flat[group_index], self.round_robin_bit16_groups[group_index]) for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params): p.data = q.data # set model fp16 weight to slices of reordered flattened buffer for param_index, param in enumerate(self.bit16_groups[group_index]): new_index = self.round_robin_bit16_indices[group_index][param_index] param.data = self.round_robin_bit16_groups[group_index][new_index].data def _round_robin_reorder(self, tensor_list, num_partitions): # disable round robin if need to debug something # return tensor_list, list(range(len(tensor_list))) partition_tensors = {} for i, tensor in enumerate(tensor_list): j = i % num_partitions if not j in partition_tensors: partition_tensors[j] = [] partition_tensors[j].append((i, tensor)) reordered_tensors = [] reordered_indices = {} for partition_index in partition_tensors.keys(): for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]): reordered_indices[original_index] = len(reordered_tensors) reordered_tensors.append(tensor) return reordered_tensors, reordered_indices def _release_ipg_buffers(self): if 
def initialize_optimizer_states(self):
    """Give each master partition a zero gradient and trigger optimizer state creation.

    A zero gradient of ``int(self.partition_size[i])`` elements is attached to every
    entry of ``single_partition_of_fp32_groups`` (pinned when ``cpu_offload_pin_memory``).
    The optimizer is then stepped once, or rebuilt for Adagrad, so its state exists.
    Without CPU offload the temporary gradients are cleared again.
    """
    for idx in range(len(self.bit16_groups)):
        master_partition = self.single_partition_of_fp32_groups[idx]
        single_grad_partition = torch.zeros(int(self.partition_size[idx]),
                                            dtype=master_partition.dtype,
                                            device=self.device)
        if self.cpu_offload_pin_memory:
            master_partition.grad = get_accelerator().pin_memory(single_grad_partition)
        else:
            master_partition.grad = single_grad_partition

    # Initialize the optimizer states with the flattened fp32 partition.
    # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
    # which do lazy initialization of the state at the first call to step.
    if not isinstance(self.optimizer, torch.optim.Adagrad):
        self.optimizer.step()
    else:
        self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)

    if not self.cpu_offload:
        for master_partition in self.single_partition_of_fp32_groups:
            master_partition.grad = None  # class init

    return
def get_first_param_index(self, group_id, param_group, partition_id):
    """Return the index of the first param in ``param_group`` mapped to ``partition_id``.

    Returns ``None`` when no param of the group is mapped to that partition.
    """
    for position, param in enumerate(param_group):
        param_id = self.get_param_id(param)
        owners = self.param_to_partition_ids[group_id][param_id]
        if partition_id in owners:
            return position
    return None


def initialize_gradient_partitioning_data_structures(self):
    """Create the per-group / per-partition bookkeeping dicts used for gradient partitioning."""
    for i, param_group in enumerate(self.round_robin_bit16_groups):
        total_partitions = dist.get_world_size(group=self.real_zp_process_group[i])

        # one fresh dict per group in each table
        per_group_tables = (
            self.param_to_partition_ids,
            self.is_partition_reduced,
            self.total_grads_in_partition,
            self.remaining_grads_in_partition,
            self.is_grad_computed,
            self.grad_partition_insertion_offset,
            self.grad_start_offset,
            self.first_param_index_in_partition,
        )
        for table in per_group_tables:
            table[i] = {}

        for partition_id in range(total_partitions):
            self.is_grad_computed[i][partition_id] = {}
            self.grad_partition_insertion_offset[i][partition_id] = {}
            self.grad_start_offset[i][partition_id] = {}
            self.total_grads_in_partition[i][partition_id] = 0
            self.initialize_gradient_partition(i, param_group, partition_id)
            self.is_partition_reduced[i][partition_id] = False
            first_index = self.get_first_param_index(i, param_group, partition_id)
            self.first_param_index_in_partition[i][partition_id] = first_index
clear previously reduced grads of other partitions self._clear_previous_reduced_grads() if self.cpu_offload is False: for i, _ in enumerate(self.bit16_groups): if not i in self.averaged_gradients or self.averaged_gradients[i] is None: self.averaged_gradients[i] = self.get_flat_partition( self.params_in_partition[i], self.first_offset[i], self.partition_size[i], dtype=self.gradient_accumulation_dtype, device=get_accelerator().current_device_name(), return_tensor_list=True) else: avg_new = self.get_flat_partition(self.params_in_partition[i], self.first_offset[i], self.partition_size[i], dtype=self.gradient_accumulation_dtype, device=get_accelerator().current_device_name(), return_tensor_list=True) for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new): accumulated_grad.add_(new_avg_grad) self._release_ipg_buffers() # No need to keep the gradients anymore. # All gradients required by the step # are in self.averaged_gradients self.zero_grad(set_to_none=True) see_memory_usage(f"End ipg_epilogue") # resets all partition to no reduced # sets remaining grads to the total number of grads in each partition # set is grad computed to false for all grads in partition def reset_partition_gradient_structures(self): for i, _ in enumerate(self.bit16_groups): total_partitions = dist.get_world_size(group=self.real_zp_process_group[i]) for partition_id in range(total_partitions): self.is_partition_reduced[i][partition_id] = False self.remaining_grads_in_partition[i][partition_id] = self.total_grads_in_partition[i][partition_id] for param_id in self.is_grad_computed[i][partition_id]: self.is_grad_computed[i][partition_id][param_id] = False def initialize_gradient_partition(self, i, param_group, partition_id): def set_key_value_list(dictionary, key, value): if key in dictionary: dictionary[key].append(value) else: dictionary[key] = [value] def increment_value(dictionary, key): if key in dictionary: dictionary[key] += 1 else: dictionary[key] = 1 partition_size = 
self.partition_size[i] start_index = partition_size * partition_id end_index = partition_size * (partition_id + 1) current_index = 0 first_offset = 0 for param in param_group: param_size = param.numel() param_id = self.get_param_id(param) if start_index <= current_index < end_index: set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id) increment_value(self.total_grads_in_partition[i], partition_id) self.is_grad_computed[i][partition_id][param_id] = False self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index self.grad_start_offset[i][partition_id][param_id] = 0 elif current_index < start_index < (current_index + param_size): assert (first_offset == 0 ), "This can happen either zero or only once as this must be the first tensor in the partition" first_offset = start_index - current_index set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id) increment_value(self.total_grads_in_partition[i], partition_id) self.is_grad_computed[i][partition_id][param_id] = False self.grad_partition_insertion_offset[i][partition_id][param_id] = 0 self.grad_start_offset[i][partition_id][param_id] = first_offset current_index = current_index + param_size def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() def fill_grad_accum_attribute(self): for group in self.bit16_groups: for param in group: if param.grad is not None: if param.grad_accum is None: param.grad_accum = param.grad.to(self.gradient_accumulation_dtype) else: param.grad_accum.add_( param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape)) param.grad = None def get_gradient_for_reduction(self, param): if self.use_grad_accum_attribute: return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None else: return param.grad def get_param_gradient_attribute(self, param): return param.grad_accum if self.use_grad_accum_attribute else param.grad # Clear the tensor the 
reduction gradient attribute is pointing to def clear_grad_attribute(self, param): if self.use_grad_accum_attribute: param.grad_accum = None else: param.grad = None def create_reduce_and_remove_grad_hooks(self): self.grad_accs = [] for i, param_group in enumerate(self.bit16_groups): for param in param_group: if param.requires_grad: def wrapper(param, i): param_tmp = param.expand_as(param) grad_acc = param_tmp.grad_fn.next_functions[0][0] def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param, i) grad_acc.register_hook(reduce_partition_and_remove_grads) self.grad_accs.append(grad_acc) wrapper(param, i) def get_param_id(self, param): unique_id = id(param) return self.param_id[unique_id] def report_ipg_memory_usage(self, tag, param_elems): elem_count = self.elements_in_ipg_bucket + param_elems percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size see_memory_usage( f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}" ) # create a flat tensor aligned at the alignment boundary def flatten_dense_tensors_aligned(self, tensor_list, alignment): return self.flatten(align_dense_tensors(tensor_list, alignment)) ############### Independent Partition Gradient ######################## def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): grad_reduc = self.get_gradient_for_reduction(param) if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) self.reduce_ipg_grads() if self.contiguous_gradients and self.overlap_comm: # Swap ipg_index between 0 and 1 self.ipg_index = 1 - self.ipg_index self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel()) param_id = self.get_param_id(param) assert self.params_already_reduced[param_id] == False, \ f"The parameter {param_id} has already been reduced. 
\ Gradient computed twice for this partition. \ Multiple gradient reduction is currently not supported" if self.contiguous_gradients: if param.numel() > self.reduce_bucket_size: self.extra_large_param_to_reduce = param else: # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel()) new_grad_tensor.copy_(grad_reduc.view(-1)) grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) self.elements_in_ipg_bucket += param.numel() assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" self.grads_in_ipg_bucket.append(grad_reduc) self.params_in_ipg_bucket.append((i, param, param_id)) # make sure the average tensor function knows how to average the gradients if is_moe_param(param): self.ipg_bucket_has_moe_params = True self.report_ipg_memory_usage("End ipg_remove_grads", 0) def print_rank_0(self, message): if dist.get_rank() == 0: logger.info(message) def gradient_reduction_w_predivide(self, tensor): dp_world_size = dist.get_world_size(group=self.dp_process_group) tensor_to_allreduce = tensor if self.communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(self.communication_data_type) if self.postscale_gradients: if self.gradient_predivide_factor != 1.0: tensor_to_allreduce.mul_(1. 
/ self.gradient_predivide_factor) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.gradient_predivide_factor != dp_world_size: tensor_to_allreduce.mul_(self.gradient_predivide_factor / (dp_world_size / float(self.sequence_parallel_size))) else: tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: tensor.copy_(tensor_to_allreduce) return tensor def allreduce_and_copy_with_multiple_ranks(self, small_bucket, log=None, divide=True, process_group=None, bucket_ranks=None): process_group = self.zp_process_group if process_group is None else process_group allreduced = self.allreduce_bucket(small_bucket, log=log, divide=divide, process_group=process_group) for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks): if dist.get_rank(group=process_group) == bucket_rank: buf.copy_(synced) def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, divide=True, process_group=None): small_bucket = [] small_bucket_ranks = [] numel = 0 allreduce_sizes = [] for i, bucket_elem in enumerate(bucket): rank, tensor = bucket_elem small_bucket.append(tensor) small_bucket_ranks.append(rank) numel = numel + tensor.numel() if numel > numel_per_bucket: self.allreduce_and_copy_with_multiple_ranks(small_bucket, log=None, divide=divide, process_group=process_group, bucket_ranks=small_bucket_ranks) small_bucket = [] small_bucket_ranks = [] numel = 0 if len(small_bucket) > 0: self.allreduce_and_copy_with_multiple_ranks(small_bucket, log=None, divide=divide, process_group=process_group, bucket_ranks=small_bucket_ranks) def average_tensor(self, tensor): if self.overlap_comm: stream = self.reduction_stream if not get_accelerator().is_synchronized_device(): stream.wait_stream(get_accelerator().current_stream()) else: stream = 
get_accelerator().current_stream() with get_accelerator().stream(stream): if not self.reduce_scatter: self.gradient_reduction_w_predivide(tensor) return # Accumulate destination ranks and bucket offsets for each gradient slice. # Note: potential future optimization, record access pattern of parameters # in backward pass and partition gradients w.r.t. access pattern so that our # bucket is guaranteed to be contiguous w.r.t. ranks rank_and_offsets = [] real_dp_process_group = [] curr_size = 0 prev_id, prev_process_group = -1, None process_group = self.zp_process_group # count = 0 for i, param, param_id in self.params_in_ipg_bucket: process_group = self.zp_process_group grad_reduc = self.get_gradient_for_reduction(param) # Averages gradients at parameter level if ipg has a moe param # Otherwise averaging is done at the entire buffer level at the end of the loop # MoE param have different groups if self.ipg_bucket_has_moe_params: process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( param) else self.zp_process_group grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size)) partition_ids = self.param_to_partition_ids[i][param_id] assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids ]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}" partition_size = self.partition_size[i] # Get all partition ids + their offsets partition_ids_w_offsets = [] for partition_id in partition_ids: offset = self.grad_start_offset[i][partition_id][param_id] partition_ids_w_offsets.append((partition_id, offset)) partition_ids_w_offsets.sort(key=lambda t: t[1]) # Calculate rank and offsets for grad slices for idx in range(len(partition_ids_w_offsets)): partition_id, offset = partition_ids_w_offsets[idx] # if dist.get_rank() == 0 and count < 100: # print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} 
real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}") # count += 1 # Calculate numel for grad slice depending on partition location if idx == len(partition_ids_w_offsets) - 1: # Last partition_id uses its own offset numel = param.numel() - offset else: # Set numel to next partition's offset numel = partition_ids_w_offsets[idx + 1][1] - offset # Merge bucket ranges if they belong to the same rank if partition_id == prev_id and process_group == prev_process_group: prev_pid, prev_size, prev_numel = rank_and_offsets[-1] rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel) else: rank_and_offsets.append((partition_id, curr_size, numel)) real_dp_process_group.append(process_group) curr_size += numel prev_id, prev_process_group = partition_id, process_group if not self.ipg_bucket_has_moe_params: tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) buckets = {} for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets): grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else ( dst, real_dp_process_group[i]) if bucket_key not in buckets: buckets[bucket_key] = [] if self.use_multi_rank_bucket_allreduce: buckets[bucket_key].append((dst, grad_slice)) else: buckets[bucket_key].append(grad_slice) for bucket_key in buckets: if self.use_multi_rank_bucket_allreduce: self.allreduce_and_scatter(buckets[bucket_key], numel_per_bucket=self.reduce_bucket_size, divide=self.ipg_bucket_has_moe_params, process_group=bucket_key) else: dst, process_group = bucket_key self.allreduce_no_retain(buckets[bucket_key], numel_per_bucket=self.reduce_bucket_size, rank=dst, divide=self.ipg_bucket_has_moe_params, process_group=process_group) ############################################################################## ############################# CPU Offload Methods############################# 
############################################################################## def get_grad_position(self, group_id, tensor_list, first_offset, partition_size): current_offset = 0 for i, tensor in enumerate(tensor_list): param_id = self.get_param_id(tensor) param_start_offset = 0 num_elements = tensor.numel() # we need to offset to get to the right element if i == 0 and first_offset > 0: tensor_offset = first_offset num_elements = num_elements - tensor_offset param_start_offset = first_offset # we dont need all elements of the tensor if num_elements > (partition_size - current_offset): num_elements = partition_size - current_offset self.grad_position[param_id] = [ int(group_id), int(param_start_offset), int(current_offset), int(num_elements) ] current_offset += num_elements def update_overflow_tracker_for_param_grad(self, param): grad_accum = self.get_param_gradient_attribute(param) if grad_accum is not None and self._has_inf_or_nan(grad_accum.data): self.local_overflow = True def _get_offload_gradient_dict(self): for param_group_index, _ in enumerate(self.optimizer.param_groups): self.offload_gradient_dict[param_group_index] = [] for lp_param in self.params_in_partition[param_group_index]: param_id = self.get_param_id(lp_param) [_, _, dest_offset, num_elements] = self.grad_position[param_id] dest_tensor = self.single_partition_of_fp32_groups[param_group_index].grad.view(-1).narrow( 0, dest_offset, num_elements) self.offload_gradient_dict[param_group_index].append(dest_tensor) def async_accumulate_grad_in_cpu_via_gpu(self, param): param_id = self.get_param_id(param) [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] # copy to a preexisiting buffer to avoid memory allocation penalty dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel()) # buffer for storing gradients for this parameter in CPU def buffer_to_accumulate_to_in_cpu(): if not self.fp16_master_weights_and_gradients: buffer = 
torch.zeros(param.numel(), dtype=param.dtype, device=self.device) return get_accelerator().pin_memory(buffer) if self.cpu_offload_pin_memory else buffer else: return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) # accumulate gradients into param.grad_accum or parts of it that belongs to this partition def accumulate_gradients(): grad_accum = self.get_param_gradient_attribute(param) if not self.fp16_master_weights_and_gradients: dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) grad_accum.data.view(-1).add_(dest_buffer) else: dest_buffer.narrow(0, source_offset, num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) grad_accum.data.view(-1).narrow(0, source_offset, num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements)) # move accumulated gradients back to CPU def copy_gradients_to_cpu(): grad_accum = self.get_param_gradient_attribute(param) if not self.fp16_master_weights_and_gradients: self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True) else: self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow( 0, source_offset, num_elements), non_blocking=True) if param_id not in self.accumulated_grads_in_cpu: self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu() if self.micro_step_id > 0: accumulate_gradients() # at the boundary we will send 32bit directly if not self.is_gradient_accumulation_boundary: copy_gradients_to_cpu() def set_norm_for_param_grad(self, param): param_id = self.get_param_id(param) grad_accum = self.get_param_gradient_attribute(param) accumulated_grad = self.accumulated_grads_in_cpu[ param_id] if self.gradient_accumulation_steps > 1 else grad_accum [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] start = source_offset accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements) 
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2) def set_norm_for_param_grad_in_gpu(self, param): param_id = self.get_param_id(param) grad_accum = self.get_param_gradient_attribute(param) if grad_accum is None: accumulated_grad = param.grad else: accumulated_grad = grad_accum [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] start = source_offset accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements) self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2) def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): param_id = self.get_param_id(param) [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) grad_accum = self.get_param_gradient_attribute(param) if grad_accum is None: src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) else: src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements) if not self.fp16_master_weights_and_gradients: src_tensor = src_tensor.float() dest_tensor.copy_(src_tensor, non_blocking=True) param.grad = None # offload only def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm = 0.0 norm_type = 2.0 for p in params: # Pipeline parallelism may replicate parameters. Avoid multi-counting. 
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: continue if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) # as some model have trainable parameters but skipped in training, # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, # so they have no norm_for_param_grads if param_id in self.norm_for_param_grads: param_norm = self.norm_for_param_grads[param_id] total_norm += param_norm.item() ** 2 else: # As unused parameters in modules may not be expected sometimes, # add an explicit error msg when it occurred and an option to # avoid the error assert self.ignore_unused_parameters, """ This assert indicates that your module has parameters that were not used in producing loss. You can avoid this assert by (1) enable ignore_unused_parameters option in zero_optimization config; (2) making sure all trainable parameters and `forward` function outputs participate in calculating loss. """ # Sum across all model parallel GPUs. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) total_norm = total_norm_cuda[0].item() ** (1. 
/ norm_type) if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: total_norm = -1 return total_norm ############################################################################################ def copy_grads_in_partition(self, param): if self.cpu_offload: if self.gradient_accumulation_steps > 1: self.async_accumulate_grad_in_cpu_via_gpu(param) if self.is_gradient_accumulation_boundary: self.set_norm_for_param_grad_in_gpu(param) self.update_overflow_tracker_for_param_grad(param) self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) return # print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}") if self.grads_in_partition is None: self.grads_in_partition_offset = 0 total_size = 0 for group in self.params_in_partition: for param_in_partition in group: total_size += param_in_partition.numel() see_memory_usage(f"before copying {total_size} gradients into partition") self.grads_in_partition = torch.empty(int(total_size), dtype=self.dtype, device=get_accelerator().current_device_name()) see_memory_usage(f"after copying {total_size} gradients into partition") grad_reduc = self.get_gradient_for_reduction(param) # The allreduce buffer will be rewritten. 
Copy the gradients in partition to a new buffer new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel()) new_grad_tensor.copy_(grad_reduc.view(-1)) grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) # print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}") self.grads_in_partition_offset += param.numel() def reduce_ipg_grads(self): if self.contiguous_gradients: if self.extra_large_param_to_reduce is not None: assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen" _, _, param_id = self.params_in_ipg_bucket[0] assert self.get_param_id(self.extra_large_param_to_reduce ) == param_id, "param in ipg bucket does not match extra-large param" extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce) self.average_tensor(extra_large_grad_reduc.view(-1)) self.extra_large_param_to_reduce = None else: self.average_tensor(self.ipg_buffer[self.ipg_index]) else: self.buffered_reduce_fallback(None, self.grads_in_ipg_bucket, elements_per_buffer=self.elements_in_ipg_bucket) if self.overlap_comm: stream = self.reduction_stream elif self.cpu_offload: # TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed. # get_accelerator().synchronize() # stream = self.copy_grad_stream stream = get_accelerator().current_stream() else: stream = get_accelerator().current_stream() with get_accelerator().stream(stream): for _, param, param_id in self.params_in_ipg_bucket: assert self.params_already_reduced[param_id] == False, \ f"The parameter {param_id} has already been reduced. \ Gradient computed twice for this partition. 
\ Multiple gradient reduction is currently not supported" self.params_already_reduced[param_id] = True if self.partition_gradients: if not self.is_param_in_current_partition[param_id]: if self.overlap_comm and self.contiguous_gradients is False: # Clear grads of other partitions during the next reduction # to avoid clearing them before the reduction is complete. if self.previous_reduced_grads is None: self.previous_reduced_grads = [] self.previous_reduced_grads.append(param) else: self.clear_grad_attribute(param) elif self.contiguous_gradients: self.copy_grads_in_partition(param) else: # zero stage 1 - partition only optimizer state if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: self.copy_grads_in_partition(param) self.grads_in_ipg_bucket = [] self.params_in_ipg_bucket = [] self.ipg_bucket_has_moe_params = False self.elements_in_ipg_bucket = 0 ##################################################################### def reduce_ready_partitions_and_remove_grads(self, param, i): if self.partition_gradients or self.is_gradient_accumulation_boundary: self.reduce_independent_p_g_buckets_and_remove_grads(param, i) def zero_reduced_gradients(self, partition_id, i): def are_all_related_partitions_reduced(params_id): for partition_id in self.param_to_partition_ids[i][params_id]: if not self.is_partition_reduced[i][partition_id]: return False return True for params_id in self.is_grad_computed[i][partition_id]: if are_all_related_partitions_reduced(params_id): self.param_dict[params_id].grad = None # dead code def flatten_and_print(self, message, tensors, start=0, n=5): flatten_tensor = self.flatten(tensors) def print_func(): logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) self.sequential_execution(print_func, message) def get_grads_to_reduce(self, i, partition_id): def get_reducible_portion(key): grad = self.param_dict[key].grad total_elements = grad.numel() start = self.grad_start_offset[i][partition_id][key] num_elements = 
min(total_elements - start, self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key]) if not pg_correctness_test: if num_elements == total_elements: return grad else: return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements)) else: if num_elements == total_elements: return grad.clone() else: return grad.clone().contiguous().view(-1).narrow(0, int(start), int(num_elements)) grads_to_reduce = [] for key in self.is_grad_computed[i][partition_id]: grad = get_reducible_portion(key) grads_to_reduce.append(grad) return grads_to_reduce def sequential_execution(self, function, message, group=None): if group is None: group = self.zp_process_group if dist.get_rank(group=group) == 0: logger.info(message) for id in range(dist.get_world_size(group=group)): if id == dist.get_rank(group=group): function() dist.barrier(group=group) def set_none_gradients_to_zero(self, i, partition_id): for param_id in self.is_grad_computed[i][partition_id]: param = self.param_dict[param_id] if param.grad is None: param.grad = torch.zero_like(param) ######################Reduction Related Methods############################## def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None): rank = None tensor = self.flatten(bucket) process_group = self.zp_process_group if process_group is None else process_group tensor_to_allreduce = tensor if pg_correctness_test or self.sequence_parallel_size > 1: communication_data_type = torch.float32 else: communication_data_type = self.communication_data_type if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) if divide: tensor_to_allreduce.div_( dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) tensor_to_allreduce = tensor_to_allreduce.contiguous() if rank is None: # "All Reducing" dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) else: global_rank = dist.get_global_rank(self.dp_process_group, 
rank) dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: if rank is None or rank == dist.get_rank(group=process_group): tensor.copy_(tensor_to_allreduce) return tensor def _clear_previous_reduced_grads(self): if self.previous_reduced_grads is not None: for param in self.previous_reduced_grads: self.clear_grad_attribute(param) self.previous_reduced_grads = None # if rank is specified do a reduction instead of an allreduce def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): process_group = self.zp_process_group if process_group is None else process_group if self.overlap_comm: get_accelerator().synchronize() # It is safe to clear the previously reduced grads of other partitions self._clear_previous_reduced_grads() stream = self.reduction_stream else: stream = get_accelerator().current_stream() with get_accelerator().stream(stream): allreduced = self.allreduce_bucket( small_bucket, rank=rank, log=log, divide=divide, process_group=process_group, ) if rank is None or rank == dist.get_rank(group=self.zp_process_group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) def allreduce_no_retain( self, bucket, numel_per_bucket=500000000, rank=None, log=None, divide=True, process_group=None, ): small_bucket = [] numel = 0 for tensor in bucket: small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: self.allreduce_and_copy(small_bucket, rank=rank, log=None, divide=divide, process_group=process_group) small_bucket = [] numel = 0 if len(small_bucket) > 0: self.allreduce_and_copy(small_bucket, rank=rank, log=log, divide=divide, process_group=process_group) # allows using reduction of gradients instead of using all_reduce def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None): split_buckets = 
split_half_float_double(grads)

for i, bucket in enumerate(split_buckets):
    self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log)

#############################################################################
#############################################################################
#############################################################################


# views the tensor as multiple partitions and returns
# those partitions
def get_data_parallel_partitions(self, tensor, group_id):
    """Split ``tensor`` along dim 0 into one narrow view per rank of a ZP group.

    The number of partitions is the world size of
    ``self.real_zp_process_group[group_id]``.  When the element count does not
    divide evenly, the first ``numel % world_size`` partitions get one extra
    element.

    Returns:
        list of tensor views (``tensor.narrow``), in rank order.
    """
    partitions = []

    dp = dist.get_world_size(group=self.real_zp_process_group[group_id])
    # dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])

    total_num_elements = tensor.numel()

    base_size = total_num_elements // dp
    remaining = total_num_elements % dp

    start = 0
    # `rank_idx` instead of `id` so the builtin is not shadowed
    for rank_idx in range(dp):
        partition_size = base_size
        if rank_idx < remaining:
            partition_size = partition_size + 1
        partitions.append(tensor.narrow(0, start, partition_size))
        start = start + partition_size
    return partitions


def get_partition_info(self, tensor_list, partition_size, partition_id):
    """Classify tensors by whether they overlap the flat range of a partition.

    The partition covers flat indices
    ``[partition_size * partition_id, partition_size * (partition_id + 1))``
    of the concatenation of ``tensor_list``.

    Returns:
        (params_in_partition, params_not_in_partition, first_offset) where
        ``first_offset`` is the element offset inside the first tensor of the
        partition at which the partition starts (0 when it starts on a tensor
        boundary).
    """
    params_in_partition = []
    params_not_in_partition = []

    start_index = partition_size * partition_id
    end_index = partition_size * (partition_id + 1)

    current_index = 0
    first_offset = 0

    for tensor in tensor_list:

        tensor_size = tensor.numel()

        if start_index <= current_index < end_index:
            # tensor starts inside the partition
            params_in_partition.append(tensor)

        elif current_index < start_index < (current_index + tensor_size):
            # tensor starts before the partition but reaches into it
            params_in_partition.append(tensor)

            assert (first_offset == 0
                    ), "This can happen either zero or only once as this must be the first tensor in the partition"
            first_offset = start_index - current_index

        else:
            params_not_in_partition.append(tensor)

        current_index = current_index + tensor_size

    return params_in_partition, params_not_in_partition, first_offset


def zero_grad(self, set_to_none=True):
    """
    Zero FP16 parameter grads.

    With ``set_to_none=True`` (default) both ``p.grad`` and ``p.grad_accum``
    are dropped; otherwise existing ``p.grad`` tensors are detached and zeroed
    in place.
    """
    # FP32 grad should never exist.
    # For speed, set model fp16 grad to None by default
    # zero all pointers to grad tensors
    for group in self.bit16_groups:
        for p in group:
            if set_to_none:
                p.grad = None  # epilogue and in step
                p.grad_accum = None
            else:
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()


def _model_parallel_all_reduce(self, tensor, op):
    """ Perform all reduce within model parallel group, if any.
    """
    if self.model_parallel_group is None or self.model_parallel_world_size == 1:
        pass
    else:
        dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group)


def get_grad_norm_direct(self, gradients, params, norm_type=2):
    """Compute the norm of ``gradients`` and reduce it across ranks.

    Adapted from torch.nn.utils.clip_grad.clip_grad_norm_, but this method only
    measures the gradients: nothing is clipped or modified.

    Arguments:
        gradients (Iterable[Tensor]): gradients to measure.
        params (Iterable[Tensor]): parameters paired position-by-position with
            ``gradients``; used only to decide which gradients are counted
            (pipeline-replicated parameters are skipped, and non model-parallel
            parameters are counted on model-parallel rank 0 only).
        norm_type (float or int): ``inf`` selects the max-abs norm.  Any other
            value takes the finite path, which sums squared L2 norms and then
            takes the ``1 / norm_type`` root.
            NOTE(review): that is a true p-norm only for ``norm_type == 2``.

    Returns:
        The reduced norm as a Python float, or -1 if it is inf or NaN.
    """
    norm_type = float(norm_type)
    if norm_type == inf:
        # default=0.0 keeps a rank that holds no gradients from raising on an
        # empty sequence; 0 is neutral for the MAX reductions below.
        total_norm = max((g.data.abs().max() for g in gradients), default=0.0)
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)

        # Take max across all GPUs.
        self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.0
        # if dist.get_rank() == 0:
        #    logger.info(f"Total Norm beginning {total_norm}")
        for g, p in zip(gradients, params):
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                continue
            if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                param_norm = g.data.double().norm(2)
                total_norm += param_norm.item() ** 2
        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

        self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)

    # `x != x` is the NaN check
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1

    return total_norm


# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):
    """Build the flat gradient partition of ``partition_size`` elements.

    For each tensor the gradient returned by
    ``self.get_param_gradient_attribute`` is used (zeros when it is ``None``).
    The first tensor is entered at ``first_offset``; the result is zero-padded
    up to ``partition_size``.

    Returns:
        the list of pieces when ``return_tensor_list`` is true, otherwise
        ``self.flatten`` applied to that list.
    """
    flat_tensor_list = []
    current_size = 0

    for i, tensor in enumerate(tensor_list):
        grad_accum = self.get_param_gradient_attribute(tensor)
        if grad_accum is None:
            grad_accum = torch.zeros_like(tensor, dtype=dtype)

        tensor = grad_accum
        num_elements = tensor.numel()
        tensor_offset = 0

        # we need to offset to get to the right element
        if i == 0 and first_offset > 0:
            tensor_offset = first_offset
            num_elements = num_elements - tensor_offset

        # we dont need all elements of the tensor
        if num_elements > (partition_size - current_size):
            num_elements = partition_size - current_size

        # we need a narrow view of the tensor based on the tensor offset and number of elements that
        # we need from this tensor
        if tensor_offset > 0 or num_elements < tensor.numel():
            flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements)))
        else:
            flat_tensor_list.append(tensor)

        current_size = current_size + num_elements

    # this means its the last partition and does not align with the dp boundary. We need to pad before flattening
    if current_size < partition_size:
        flat_tensor_list.append(torch.zeros(int(partition_size - current_size), dtype=dtype, device=device))

    if return_tensor_list:
        return flat_tensor_list

    return self.flatten(flat_tensor_list)


def free_grad_in_param_list(self, param_list):
    """Drop ``grad`` and ``grad_accum`` on every parameter in ``param_list``."""
    for p in param_list:
        p.grad = None  # in step
        p.grad_accum = None


def reset_cpu_buffers(self):
    """Reset the per-step bookkeeping used on the CPU-offload path."""
    self.norm_for_param_grads = {}
    self.local_overflow = False


def set_lr(self, lr):
    """Set the learning rate."""
    for param_group in self.optimizer.param_groups:
        param_group["lr"] = lr


def get_lr(self):
    """Return the current learning rate."""
    return self.optimizer.param_groups[0]["lr"]


def override_loss_scale(self, loss_scale):
    """Switch to an externally supplied loss scale (logged when it changes)."""
    if loss_scale != self.external_loss_scale:
        logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
    self.custom_loss_scaler = True
    self.external_loss_scale = loss_scale


def scaled_global_norm(self, norm_type=2):
    """Return the global L2 norm of the (still loss-scaled) gradients.

    One norm is computed per bit16 group and the list is combined by
    ``get_global_norm``.  Only ``norm_type == 2`` is accepted.
    """
    assert norm_type == 2, "only L2 norm supported"
    norm_groups = []
    # only the group index is needed; the previously computed partition id and
    # fp32 grad partition were never used here
    for i, _group in enumerate(self.bit16_groups):
        if self.cpu_offload:
            norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]))
        else:
            norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))

    if self.has_moe_layers:
        self._average_expert_grad_norms(norm_groups)

    # note that the get_global_norm function only supports l2 norm
    return get_global_norm(norm_list=norm_groups)


def get_bit16_param_group(self, group_no):
    """Return, as a one-element list, this rank's bit16 partition of a group."""
    bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
    partition_id = dist.get_rank(group=self.real_zp_process_group[group_no])
    return [bit16_partitions[partition_id]]


def _optimizer_step(self, group_no):
    """Run the wrapped optimizer on a single param group.

    ``param_groups`` is temporarily narrowed to ``[param_groups[group_no]]``
    and restored afterwards.
    """
    original_param_groups = self.optimizer.param_groups
    self.optimizer.param_groups = [original_param_groups[group_no]]
    # Disabling this as the C++ side copy & synchronize is not working correctly
    # from deepspeed.ops.adam import DeepSpeedCPUAdam
    # if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
    #    self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
    # else:
    #    self.optimizer.step()
    self.optimizer.step()
    self.optimizer.param_groups = original_param_groups


def step(self, closure=None):
    """
    Not supporting closure.

    Sequence: overflow check (fp16 only) and loss-scale update; on overflow the
    gradients are cleared and the step is skipped.  Otherwise the global grad
    norm is computed, each group's fp32 partition is unscaled/clipped and
    stepped, the result is copied into this rank's bit16 partition, and the
    bit16 partitions are all-gathered.
    """
    self.micro_step_id = -1

    see_memory_usage(f"In step before checking overflow")

    # First compute norm for all group so we know if there is overflow
    if self.dtype == torch.float16:
        self.check_overflow()

    prev_scale = self.loss_scale
    self._update_scale(self.overflow)
    if self.overflow:
        see_memory_usage('After overflow before clearing gradients')
        self.zero_grad(set_to_none=True)
        if self.cpu_offload:
            self.reset_cpu_buffers()
        else:
            self.averaged_gradients = {}

        see_memory_usage('After overflow after clearing gradients')

        # start/stop every timer so the timer log stays well-formed on skipped steps
        for timer in OPTIMIZER_TIMERS:
            self.timers(timer).start()
            self.timers(timer).stop()
        return

    # Step 1:- Calculate gradient norm using bit-16 grads
    see_memory_usage('Before norm calculation')
    scaled_global_grad_norm = self.scaled_global_norm()
    self._global_grad_norm = scaled_global_grad_norm / prev_scale
    see_memory_usage('After norm before optimizer')

    # Step 2:- run optimizer and upscaling simultaneously
    for i, group in enumerate(self.bit16_groups):
        self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
        partition_id = dist.get_rank(group=self.real_zp_process_group[i])
        if self.cpu_offload:
            single_grad_partition = self.single_partition_of_fp32_groups[i].grad
            self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

            self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
            self.timers(OPTIMIZER_STEP_TIMER).start()
            self._optimizer_step(i)

            # Disabled, this is not currently working
            # from deepspeed.ops.adam import DeepSpeedCPUAdam
            # if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):
            #    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
            #    fp32_partition = self.single_partition_of_fp32_groups[i]
            #    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
            bit16_partitions = self.parallel_partitioned_bit16_groups[i]
            fp32_partition = self.single_partition_of_fp32_groups[i]
            bit16_partitions[partition_id].data.copy_(fp32_partition.data)

            self.timers(OPTIMIZER_STEP_TIMER).stop()
        else:
            # free gradients for all the parameters that are not updated by this process(ZeRO stage2)
            self.free_grad_in_param_list(self.params_not_in_partition[i])

            # create a flat gradients for parameters updated by this process
            # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
            if partition_id == dist.get_world_size(group=self.real_zp_process_group[i]) - 1:
                single_grad_partition = self.flatten_dense_tensors_aligned(
                    self.averaged_gradients[i],
                    int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype)
            else:
                single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
                    self.single_partition_of_fp32_groups[i].dtype)
            assert single_grad_partition.numel() == self.partition_size[i], \
                "averaged gradients have different number of elements that partition size {} {} {} {}".format(
                    single_grad_partition.numel(), self.partition_size[i], i, partition_id)

            self.single_partition_of_fp32_groups[i].grad = single_grad_partition
            # release all the gradient since we have already created a necessary copy in dp_grad_partition(ZeRO stage2)
            self.free_grad_in_param_list(self.params_in_partition[i])

            self.averaged_gradients[i] = None

            self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

            self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()

            # Step 3:- run the optimizer if no offloading
            self.timers(OPTIMIZER_STEP_TIMER).start()
            self._optimizer_step(i)
            # Step 4:- get rid of the fp32 gradients. Not needed anymore
            self.single_partition_of_fp32_groups[i].grad = None
            del single_grad_partition
            bit16_partitions = self.parallel_partitioned_bit16_groups[i]
            fp32_partition = self.single_partition_of_fp32_groups[i]
            bit16_partitions[partition_id].data.copy_(fp32_partition.data)
            self.timers(OPTIMIZER_STEP_TIMER).stop()

    see_memory_usage('After optimizer before all-gather')
    if self.cpu_offload:
        self.reset_cpu_buffers()

    self.timers(OPTIMIZER_ALLGATHER_TIMER).start()
    # if dist.get_rank(group=self.dp_process_group) == 0:
    #     pdb.set_trace()  # or use another debugging tool
    # Gather the updated weights from everyone.
    # Then all partitions of the model parameters are updated and ready for next round forward.
    all_gather_into_tensor_dp_groups(groups_flat=self.bit16_groups_flat,
                                     partitioned_param_groups=self.parallel_partitioned_bit16_groups,
                                     zp_process_group=self.real_zp_process_group)
    self.timers(OPTIMIZER_ALLGATHER_TIMER).stop()

    # TODO: we probably don't need this? just to be safe
    for i in range(len(self.bit16_groups)):
        self._update_model_bit16_weights(i)

    self.timers.log(OPTIMIZER_TIMERS)
    see_memory_usage('After zero_optimizer step')

    return


@torch.no_grad()
def update_lp_params(self):
    """Copy each fp32 partition into this rank's bit16 partition, then all-gather."""
    for i, (bit16_partitions, fp32_partition) in enumerate(
            zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
        partition_id = dist.get_rank(group=self.real_zp_process_group[i])
        bit16_partitions[partition_id].data.copy_(fp32_partition.data)
        # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
        # if i == 0:
        #     print_rank_0(f'{fp32_partition[:10]=}', force=True)
    all_gather_into_tensor_dp_groups(groups_flat=self.bit16_groups_flat,
                                     partitioned_param_groups=self.parallel_partitioned_bit16_groups,
                                     zp_process_group=self.real_zp_process_group)


def _average_expert_grad_norms(self, norm_groups):
    """Replace, in place, the norm of every MoE group by its data-parallel average.

    Each rank contributes ``norm / world_size`` and the contributions are
    all-reduced over ``self.dp_process_group``.
    """
    for i, norm in enumerate(norm_groups):
        if self.is_moe_param_group[i]:
            scaled_norm = norm * 1.0 / float(dist.get_world_size(group=self.dp_process_group))
            scaled_norm_tensor = torch.tensor(scaled_norm,
                                              device=get_accelerator().device_name(),
                                              dtype=torch.float)
            dist.all_reduce(scaled_norm_tensor, group=self.dp_process_group)
            norm_groups[i] = scaled_norm_tensor.item()


def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
    """Divide gradients in place by the loss scale, folding in gradient clipping.

    Args:
        grad_groups_flat: iterable whose items are tensors or lists of tensors.
        total_norm: global gradient norm, still multiplied by the loss scale.
    """
    # compute combined scale factor for this group
    combined_scale = self.loss_scale
    if self.clip_grad > 0.:
        # norm is in fact norm*scale
        clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
        if clip > 1:
            combined_scale = clip * self.loss_scale

    for grad in grad_groups_flat:
        if isinstance(grad, list):
            sub_partitions = grad
            for g in sub_partitions:
                g.data.mul_(1. / combined_scale)
        else:
            grad.data.mul_(1. / combined_scale)


def _check_overflow(self, partition_gradients=True):
    """Store the result of ``has_overflow`` in ``self.overflow``."""
    self.overflow = self.has_overflow(partition_gradients)


# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
    """Return True if any ``p.grad`` in ``params`` contains inf or NaN.

    ``is_grad_list`` is accepted but not used.
    """
    for p in params:
        if p.grad is not None and self._has_inf_or_nan(p.grad.data):
            return True

    return False


def has_overflow_partitioned_grads_serial(self):
    """Return True if any tensor in ``self.averaged_gradients`` has inf or NaN."""
    for i in range(len(self.bit16_groups)):
        for j, grad in enumerate(self.averaged_gradients[i]):
            if grad is not None and self._has_inf_or_nan(grad.data, j):
                return True
    return False


def has_overflow(self, partition_gradients=True):
    """Return True if any participating rank saw inf/NaN gradients.

    The local flag is MAX-reduced over ``self.dp_process_group`` (partitioned
    path only) and then over the model parallel group.
    """
    if partition_gradients:
        overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
        overflow_gpu = get_accelerator().ByteTensor([overflow])
        # This will capture overflow across all data parallel and expert parallel process
        # Since expert parallel process are a subset of data parallel process
        dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

    else:
        params = []
        for group in self.bit16_groups:
            for param in group:
                params.append(param)

        overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
        overflow_gpu = get_accelerator().ByteTensor([overflow])

    # Since each model parallel GPU carries only part of the model,
    # make sure overflow flag is synced across all the model parallel GPUs
    self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)

    overflow = overflow_gpu[0].item()
    return bool(overflow)


# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
    """Return True if the float32 sum of ``x`` is inf or NaN.

    ``j`` is accepted but not used.
    """
    try:
        # if x is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as x
        # (which is true for some recent version of pytorch).
        cpu_sum = float(x.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # cpu_sum = float(x.sum())
    except RuntimeError as instance:
        # We want to check if inst is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
        return False


def backward(self, loss, retain_graph=False):
    """
    :attr:`backward` performs the following steps:

    1. fp32_loss = loss.float()
    2. scaled_loss = fp32_loss*loss_scale
    3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
    """
    self.micro_step_id += 1

    if self.contiguous_gradients:
        self.ipg_buffer = []
        buf_0 = torch.empty(int(self.reduce_bucket_size),
                            dtype=self.dtype,
                            device=get_accelerator().current_device_name())
        self.ipg_buffer.append(buf_0)

        # Use double buffers to avoid data access conflict when overlap_comm is enabled.
        if self.overlap_comm:
            buf_1 = torch.empty(int(self.reduce_bucket_size),
                                dtype=self.dtype,
                                device=get_accelerator().current_device_name())
            self.ipg_buffer.append(buf_1)
        self.ipg_index = 0

    if self.custom_loss_scaler:
        # NOTE(review): retain_graph is not forwarded on this path
        scaled_loss = self.external_loss_scale * loss
        scaled_loss.backward()
    else:
        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

    # Only for Stage 1, Mode 2
    if self.use_grad_accum_attribute:
        self.fill_grad_accum_attribute()


def check_overflow(self, partition_gradients=True):
    """Public wrapper around ``_check_overflow``."""
    self._check_overflow(partition_gradients)


def _update_scale(self, has_overflow=False):
    """Forward the overflow flag to the loss scaler."""
    self.loss_scaler.update_scale(has_overflow)


# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
    return self.optimizer.state


def _set_state(self, value):
    self.optimizer.state = value


state = property(_get_state, _set_state)


# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
    return self.optimizer.param_groups


def _set_param_groups(self, value):
    self.optimizer.param_groups = value


param_groups = property(_get_param_groups, _set_param_groups)


# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
    if self.custom_loss_scaler:
        return self.external_loss_scale
    else:
        return self.loss_scaler.cur_scale


def _set_loss_scale(self, value):
    self.loss_scaler.cur_scale = value


loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)


# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding): groups_without_padding = [] for i, group in enumerate(groups_with_padding): lean_length = group.numel() - self.groups_padding[i] groups_without_padding.append(group[:lean_length]) return groups_without_padding # Return optimizer state after removing paddings that are added for alignment. def _get_state_without_padding(self, state_with_padding, padding): lean_state = {} for key, value in state_with_padding.items(): if torch.is_tensor(value): lean_length = value.numel() - padding lean_state[key] = value[:lean_length] else: lean_state[key] = value return lean_state # Return base optimizer states. # This method assumes that each param group contains a single flattened tensor. def _get_base_optimizer_state(self): optimizer_groups_state = [] for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] lean_optimizer_state = self._get_state_without_padding(self.optimizer.state[p], self.groups_padding[i]) optimizer_groups_state.append(lean_optimizer_state) return optimizer_groups_state def state_dict(self): """ Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict of the contained Pytorch optimizer. 
Example:: checkpoint = {} checkpoint['model'] = model.state_dict() checkpoint['optimizer'] = optimizer.state_dict() torch.save(checkpoint, "saved.pth") """ state_dict = {} state_dict[LOSS_SCALER] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict[CLIP_GRAD] = self.clip_grad if self.elastic_checkpoint: state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state() if "step" in self.optimizer.param_groups[0]: # Assuming "step" is the only item that changes through training iterations assert all(group["step"] == self.optimizer.param_groups[0]["step"] for group in self.optimizer.param_groups), "All param groups must have the same step value" state_dict[BASE_OPTIMIZER_STATE_STEP] = self.optimizer.param_groups[0]["step"] else: state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() # Remove paddings for DP alignment to enable loading for other alignment values fp32_groups_without_padding = self._get_groups_without_padding(self.single_partition_of_fp32_groups) state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding state_dict[ ZERO_STAGE] = ZeroStageEnum.gradients if self.partition_gradients else ZeroStageEnum.optimizer_states state_dict[GROUP_PADDINGS] = self.groups_padding state_dict[PARTITION_COUNT] = self.partition_count state_dict[DS_VERSION] = version state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings return state_dict # Restore base optimizer fp32 weights from elastic checkpoint by: # 1) Merging fp32 weights from checkpoints of all partitions # 2) Extracting fp32 weights for current partition from merged weights # 3) Using extracted weights to update base optimizer weights directly. 
def _restore_from_elastic_fp32_weights(self, all_state_dict): merged_single_partition_of_fp32_groups = [] for i in range(len(self.single_partition_of_fp32_groups)): partition_id = dist.get_rank(group=self.real_zp_process_group[i]) merged_partitions = [sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict] if self.is_moe_group(self.optimizer.param_groups[i]): ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name']) merged_partitions = [merged_partitions[i] for i in ranks] flat_merged_partitions = self.flatten_dense_tensors_aligned( merged_partitions, self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_zp_process_group[i])) dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i) merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id]) for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups): current.data.copy_(saved.data) # Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights def _restore_from_bit16_weights(self): for group_id, (bit16_partitions, fp32_partition) in enumerate( zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): partition_id = dist.get_rank(group=self.real_zp_process_group[group_id]) fp32_partition.data.copy_(bit16_partitions[partition_id].data) # Refresh the fp32 master params from the fp16 or bfloat16 copies. 
def refresh_fp32_params(self): self._restore_from_bit16_weights() # Extract optimizer state for current partition from merged states of all partitions def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id): partition_id = dist.get_rank(group=self.real_zp_process_group[group_id]) alignment = dist.get_world_size(group=self.real_zp_process_group[group_id]) if torch.is_tensor(all_partition_states[0]): flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment) dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id) return dp_partitions[partition_id] else: # Assume non-tensor states are not partitioned and equal across ranks, so return first one return all_partition_states[0] def _restore_base_optimizer_state(self, base_optimizer_group_states): if type(base_optimizer_group_states) == dict: base_optimizer_group_states = base_optimizer_group_states['state'] for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] for key, saved in base_optimizer_group_states[i].items(): if torch.is_tensor(self.optimizer.state[p][key]): dst_tensor = self.optimizer.state[p][key] src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) self.optimizer.state[p][key].data.copy_(src_tensor.data) else: self.optimizer.state[p][key] = saved def get_ep_ranks(self, rank=0, group_name=None): from deepspeed.utils import groups expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name) world_size = groups._get_data_parallel_world_size() rank = groups._get_expert_parallel_rank(group_name) ranks = range(rank, world_size, expert_parallel_size_) return list(ranks) # Restore base optimizer state from elastic checkpoint by # 1) Merging optimizer state from checkpoints of all partitions # 2) Extracting optimizer state for current partition from the merged state # 3) Using the extracted value to directly update the base optimizer. 
def _restore_elastic_base_optimizer_state(self, all_state_dict): base_optimizer_group_states = [] for i in range(len(self.optimizer.param_groups)): partition_states = {} all_partition_group_states = [sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict] if self.is_moe_group(self.optimizer.param_groups[i]): ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name']) all_partition_group_states = [all_partition_group_states[i] for i in ranks] for key in all_partition_group_states[0].keys(): all_partition_states = [all_states[key] for all_states in all_partition_group_states] partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i) base_optimizer_group_states.append(partition_states) self._restore_base_optimizer_state(base_optimizer_group_states) # Restore step if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]: assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] for sd in all_state_dict), "State dicts of all partitions must have the same step value" loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] for param_group in self.optimizer.param_groups: param_group['step'] = loaded_param_groups_step def load_state_dict(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False, checkpoint_folder=None, load_serial=None): if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) else: self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights) def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): self._load_hp_checkpoint_state(checkpoint_folder) @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups def _load_hp_checkpoint_state(self, checkpoint_dir): checkpoint_dir = os.path.join(checkpoint_dir, "zero") optim_state_path = 
os.path.join(checkpoint_dir, "optimizer_state.pt") assert os.path.isfile( optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' optim_sd = torch.load(optim_state_path) self._load_global_state(optim_sd) tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ else self.mpu.get_tensor_model_parallel_world_size() for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bit16_groups[i]: if lp._hp_mapping is not None: # print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, tp_world_size) def _load_global_state(self, sd): self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) self.overflow = sd.get('overflow', self.overflow) self.clip_grad = sd.get(CLIP_GRAD, self.clip_grad) ckpt_version = sd.get(DS_VERSION, False) assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" ckpt_version = pkg_version.parse(ckpt_version) # zero stage 1 mode if not self.partition_gradients: required_version = pkg_version.parse("0.3.17") error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \ "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \ "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json." assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}" def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False): r"""Loading ZeRO checkpoint Arguments: state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. 
Note that the number of saved partitions may differ from number of loading partitions to support changing GPU count, specifically DP world size, between saving and loading checkpoints. load_optimizer_states: Boolean indicating whether or not to load base optimizer states load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). """ """ Loads a state_dict created by an earlier call to state_dict(). If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, whose parameters in turn came from ``model``, it is expected that the user will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... checkpoint = torch.load("saved.pth") model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) """ # I think it should actually be ok to reload the optimizer before the model. dp_rank = dist.get_rank(group=self.zp_process_group) current_rank_sd = state_dict_list[dp_rank] self._load_global_state(current_rank_sd) ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict) # padding is always at the last rank/partition # if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank # scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus # scenario-2 (expand): saving w. 2 gpus -> loading w. 
4 gpus # if load_optimizer_states: # if new_dp_size: # self.strip_padding() # self.add_padding_w_new_dp_size() # self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) if load_optimizer_states: if ckpt_is_rigid: # loading rigid ckpt into either rigid or elastic exec self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) else: if self.elastic_checkpoint: # loading elastic into elastic exec self._restore_elastic_base_optimizer_state(state_dict_list) else: # loading an elastic checkpoint into rigid exec self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE]) # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still # out of date. There are two options. # 1: Refresh the master params from the model's fp16 params. # This requires less storage but incurs precision loss. # 2: Save and restore the fp32 master copies separately. # We choose option 1 if changing DP degree and option 2 otherwise. # # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device # of their associated parameters, because it's possible those buffers might not exist yet in # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been # constructed in the same way as the one whose state_dict we are loading, the same master params # are guaranteed to exist, so we can just copy_() from the saved master params. if load_from_fp32_weights: # option 2 from above if self.elastic_checkpoint and not ckpt_is_rigid: self._restore_from_elastic_fp32_weights(state_dict_list) else: # For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient. 
for current, saved in zip(self.single_partition_of_fp32_groups, current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]): src_tensor = _get_padded_tensor(saved, current.numel()) current.data.copy_(src_tensor.data) else: # option 1 from above self._restore_from_bit16_weights() if load_optimizer_states: self._link_all_hp_params() def _handle_overflow(cpu_sum, x, i): import math rank = dist.get_rank() if rank == 0: t_i = -1 for v_i, v in enumerate(x.data.contiguous().view(-1)): if not math.isfinite(float(v)): t_i = v_i break logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}") def estimate_zero2_model_states_mem_needs(total_params, num_gpus_per_node=1, num_nodes=1, cpu_offload=True, additional_buffer_factor=1.5): total_gpus = num_nodes * num_gpus_per_node if cpu_offload: gpu_mem = 2 * total_params cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor else: gpu_mem = 4 * total_params + int(16 * total_params / total_gpus) cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor return int(cpu_mem), int(gpu_mem) def model_to_params(model): # shared params calculated only once total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) return total_params def estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=1, num_nodes=1, additional_buffer_factor=1.5): """ Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients for a given ``model`` and hardware setup. If you have an actual model object, use this function and everything will be derived automatically. If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass the ``total_params`` explicitly. 
Args: - ``model``: ``nn.Module`` object - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - ``num_nodes``: how many nodes (defaults to 1), - ``additional_buffer_factor``: estimation factor (defaults to 1.5): """ total_params = model_to_params(model) estimate_zero2_model_states_mem_needs_all_cold(total_params=total_params, num_gpus_per_node=num_gpus_per_node, num_nodes=num_nodes, additional_buffer_factor=additional_buffer_factor) def estimate_zero2_model_states_mem_needs_all_cold(total_params, num_gpus_per_node=1, num_nodes=1, additional_buffer_factor=1.5): """ Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients for a given ``model`` and hardware setup. If it's a hypothetical model, use this function where you have to pass the ``total_params`` and ``largest_layer_params`` explicitly. If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything will be derived automatically. Args: - ``total_params``: total model params - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - ``num_nodes``: how many nodes (defaults to 1), - ``additional_buffer_factor``: estimation factor (defaults to 1.5): """ def format_options(cpu_offload): enabled = [] device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else "none" enabled.append(f"offload_optimizer={device}") return ", ".join(enabled) nodes_str = "nodes" if num_nodes > 1 else "node" gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" print("Estimated memory needed for params, optim states and gradients for a:\n" f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" f"SW: Model with {int(total_params / 1e6)}M total params.") print(" per CPU | per GPU | Options") for cpu_offload in [True, False]: cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(total_params=total_params, num_gpus_per_node=num_gpus_per_node, num_nodes=num_nodes, cpu_offload=cpu_offload, 
additional_buffer_factor=additional_buffer_factor) options_str = format_options(cpu_offload=cpu_offload) print(f" {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}") ================================================ FILE: opensora/adaptor/utils.py ================================================ # Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team """ Copyright NVIDIA/Megatron Helper functions and classes from multiple sources. """ from collections.abc import Iterable from deepspeed.moe.utils import is_moe_param import os import psutil import gc from math import sqrt from packaging import version as pkg_version import torch from deepspeed import comm as dist try: from torch._six import inf except ModuleNotFoundError: from torch import inf from deepspeed.utils import groups, logger from deepspeed.runtime.constants import PIPE_REPLICATED from numpy import prod from deepspeed.accelerator import get_accelerator from deepspeed.module_inject.policy import transpose from torch.nn import functional as F torch_memory_reserved = get_accelerator().memory_reserved torch_max_memory_reserved = get_accelerator().max_memory_reserved class DummyOptim(): """ Dummy optimizer presents model parameters as a param group, this is primarily used to allow ZeRO-3 without an optimizer """ def __init__(self, params): self.param_groups = [] self.param_groups.append({'params': params}) graph_cache = {} def graph_process(replay_first_step, func, *args, **kwargs): # `func` should only contain operations on the GPU # Please ensure that the memory address of the data required by 'func' remains constant if func.__name__ not in graph_cache: cuda_stream = get_accelerator().Stream() cuda_stream.wait_stream(get_accelerator().current_stream()) with get_accelerator().stream(cuda_stream): func(*args, **kwargs) get_accelerator().current_stream().wait_stream(cuda_stream) graph_cache[func.__name__] = get_accelerator().create_graph() with 
def noop_decorator(func):
    """Identity decorator: returns ``func`` unchanged."""
    return func


class noop_context(object):
    """Context manager that does nothing on entry or exit.

    ``__enter__`` and ``__exit__`` both return ``None``, so an exception raised
    inside the ``with`` body propagates unchanged.
    """

    def __init__(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def ensure_directory_exists(filename):
    """Create the directory path to ``filename`` if it does not already exist.

    Args:
        filename (str): A file path.
    """
    dirname = os.path.dirname(filename)
    # A bare file name has an empty dirname; os.makedirs('') raises
    # FileNotFoundError, and there is nothing to create in that case.
    if dirname:
        os.makedirs(dirname, exist_ok=True)


def set_random_seed(seed):
    """Set the random seed for common PRNGs used during training: random, numpy, and torch.

    Args:
        seed (int): the seed to use
    """
    import numpy
    import random
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)


def is_model_parallel_parameter(p) -> bool:
    """Return True when ``p`` has a truthy ``model_parallel`` or ``tensor_model_parallel`` attribute."""
    if hasattr(p, 'model_parallel') and p.model_parallel:
        return True

    if hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel:
        return True

    return False


def bwc_tensor_model_parallel_rank(mpu=None):
    """Backwards-compatible way of querying the tensor model parallel rank from
    an ``mpu`` object.

    *Tensor* model parallelism means that tensors are physically split across
    processes. This contrasts with *pipeline* model parallelism, in which the
    layers are partitioned but tensors left intact.

    The API for tensor model parallelism has changed across versions and this
    helper provides a best-effort implementation across versions of ``mpu``
    objects.  The preferred mechanism is
    ``mpu.get_tensor_model_parallel_rank()``.

    This should "just work" with both Megatron-LM and DeepSpeed's pipeline
    parallelism.

    Args:
        mpu (model parallel unit, optional): The tensor model parallel rank.
            If ``mpu=None``, returns 0. Defaults to ``None``.

    Returns:
        int: the rank
    """
    if mpu is None:
        # No model parallelism is the easy case.
        return 0

    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
        return mpu.get_tensor_model_parallel_rank()
    elif hasattr(mpu, 'get_slice_parallel_rank'):
        # Some DeepSpeed + pipeline parallelism versions
        return mpu.get_slice_parallel_rank()
    else:
        # Deprecated Megatron and DeepSpeed convention
        return mpu.get_model_parallel_rank()


def copy_to_device(item, device, criterion_func):
    """
    Return a copy of tensor on specified device.
    Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.

    Parameters:
        item: tensor to copy or (possibly nested) container of tensors to copy.
        device: target device
        criterion_func: Function to restrict copy operation to items meet criterion

    Returns:
        The result of ``item.to(device)`` when ``criterion_func(item)`` is true;
        a new list/tuple/dict with the same layout whose elements were processed
        recursively when ``item`` is such a container; otherwise ``item`` itself.
    """
    if criterion_func(item):
        return item.to(device)
    elif isinstance(item, list):
        return [copy_to_device(v, device, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple([copy_to_device(v, device, criterion_func) for v in item])
    elif isinstance(item, dict):
        return {k: copy_to_device(v, device, criterion_func) for k, v in item.items()}
    else:
        return item
class CheckOverflow(object):
    '''Checks for overflow in gradient across parallel process'''

    def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False, deepspeed=None):
        """Store the reduction settings and flatten ``param_groups`` into ``self.params``.

        Args:
            param_groups: optional iterable of parameter groups, each iterable over parameters.
                When falsy, ``self.params`` is left as ``None``.
            mpu: optional model parallel unit; its groups are used for the reductions.
            zero_reduce_scatter: when true, ``has_overflow`` reduces over the world group.
            deepspeed: optional engine object whose ``*enable_backward_allreduce`` flags are read.
        """
        self.mpu = mpu
        self.params = [] if param_groups else None
        self.zero_reduce_scatter = zero_reduce_scatter
        self.deepspeed = deepspeed
        self.has_moe_params = False
        if param_groups:
            for group in param_groups:
                for param in group:
                    self.params.append(param)
                    if is_moe_param(param):
                        self.has_moe_params = True

    def check_using_norm(self, norm_group, reduce_overflow=True):
        """Return True when any rank holds ``-1`` in ``norm_group``.

        The local flag is MAX-reduced over the expert-parallel group (if MoE params
        are present), then over the model-parallel group when an ``mpu`` is set, or
        over the default group (followed by a barrier) when ``reduce_overflow`` is true.
        """
        # TODO: I don't think reduce_overflow is needed if mpu is None
        overflow = -1 in norm_group
        overflow_gpu = get_accelerator().FloatTensor([overflow])
        if self.has_moe_params:
            # In this case, we need to do an all_reduce across
            # the expert_parallel_group, so that if there was
            # an overflow due to expert weights, we detect it

            # Only need to check groups.get_largest_expert_parallel_group()
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())
        if self.mpu is not None:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
        elif reduce_overflow:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
            dist.barrier()
        overflow = overflow_gpu[0].item()
        return bool(overflow)

    def check(self, param_groups=None):
        """Check the given (or the stored) parameters for gradient overflow.

        Args:
            param_groups: optional iterable of parameter groups. When ``None`` the
                parameters captured by ``__init__`` are used.

        Returns:
            bool: result of ``has_overflow`` for the selected parameters.

        Raises:
            AssertionError: if ``param_groups`` is ``None`` and no parameters were
                stored at construction time.
        """
        params = []
        has_moe_params = False
        if param_groups is None:
            # The guard belongs here: this is the only path on which both sources
            # can be missing. Without it the failure surfaced later as a TypeError
            # from iterating ``None``.
            if self.params is None:
                raise AssertionError("self.params and param_groups both cannot be none")
            params = self.params
            has_moe_params = self.has_moe_params
        else:
            for group in param_groups:
                for param in group:
                    params.append(param)
                    if is_moe_param(param):
                        has_moe_params = True

        return self.has_overflow(params, has_moe_params=has_moe_params)

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        """Return True as soon as one parameter with a gradient has an inf/NaN gradient sum."""
        for i, p in enumerate(params):
            if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
                return True
        return False

    def has_overflow(self, params, has_moe_params=None):
        """Check ``params`` locally, then MAX-reduce the flag across the relevant process groups.

        Args:
            params: iterable of parameters to inspect.
            has_moe_params: whether to also reduce over the expert-parallel group;
                defaults to ``self.has_moe_params`` when ``None``.

        Returns:
            bool: True if any participating rank detected an overflow.
        """
        if has_moe_params is None:
            has_moe_params = self.has_moe_params
        overflow = self.has_overflow_serial(params)
        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        overflow_gpu = get_accelerator().ByteTensor([overflow])
        # deepspeed.comm.all_reduce(overflow_gpu,
        #                             op=deepspeed.comm.ReduceOp.MAX,
        #                             group=mpu.get_model_parallel_group())
        if has_moe_params:
            # All reduce this across expert_parallel_group, so that if an expert
            # overflows, we detect it here
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=groups._get_max_expert_parallel_group())
        if self.zero_reduce_scatter:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        elif self.mpu is not None:
            if self.deepspeed is not None:
                using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
                if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or (
                        not using_pipeline and self.deepspeed.enable_backward_allreduce is False):
                    dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group())
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
        elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())

        overflow = overflow_gpu[0].item()
        return bool(overflow)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x, i):
        """Return True when the float32 sum of ``x`` is inf, -inf or NaN.

        ``i`` is accepted for interface compatibility and is not used here.
        """
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if inst is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False


def _handle_overflow(cpu_sum, x, i):
    """Log, on rank 0 only, the flat index of the first non-finite entry of ``x`` (-1 if none)."""
    import math
    rank = dist.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")


def get_global_norm(norm_list):
    """ Compute total from a list of norms

    Returns the square root of the sum of squares of the entries of ``norm_list``.
    """
    total_norm = 0.0
    for norm in norm_list:
        total_norm += norm**2.0
    # logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}')
    return sqrt(total_norm)
def get_grad_norm(parameters, norm_type=2, mpu=None):
    """Get grad norm of an iterable of parameters.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. The gradients
    are only read; nothing is modified in place.
    Taken from Nvidia Megatron.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose ``.grad`` norms are measured. Entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        mpu (optional): model parallel unit; when given, the result is reduced
            over ``mpu.get_model_parallel_group()``.

    Returns:
        Total norm of the gradients (viewed as a single vector).
        -1 if the norm value is NaN or Inf.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))

    norm_type = float(norm_type)
    if norm_type == inf:
        # ``default`` keeps a rank without any gradients from raising ValueError
        # before it reaches the all_reduce below.
        total_norm = max((p.grad.data.abs().max() for p in parameters), default=0.0)
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.
        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
        for p in parameters:
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                continue

            # Filter to avoid over-counting replicated tensors from tensor
            # model parallelism
            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
                continue

            param_norm = p.grad.data.float().norm(norm_type)
            total_norm += param_norm.item()**norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # Callers receive -1 as the marker for a non-finite norm.
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1

    return total_norm
def get_weight_norm(parameters, norm_type=2, mpu=None):
    """Get norm of an iterable of parameters.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. The parameter
    values are only read; nothing is modified in place.
    Taken from Nvidia Megatron.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose values are measured
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        mpu (optional): model parallel unit; when given, the result is reduced
            over ``mpu.get_model_parallel_group()``.

    Returns:
        Total norm of the parameters (viewed as a single vector).
        -1 if the norm value is NaN or Inf.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    norm_type = float(norm_type)
    if norm_type == inf:
        # ``default`` avoids ValueError on an empty iterable so this rank still
        # takes part in the all_reduce below.
        total_norm = max((p.data.abs().max() for p in parameters), default=0.0)
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.
        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
        for p in parameters:
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                continue

            # Filter to avoid over-counting replicated tensors from tensor
            # model parallelism
            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
                continue

            param_norm = p.data.float().norm(norm_type)
            total_norm += param_norm**norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1

    return total_norm


def prefix_sum_inc(weights):
    """ Compute an inclusive prefix sum.

    The input is not modified; a new list is returned.

    Example:
        >>> prefix_sum_inc([3,4,5])
        [3, 7, 12]
    """
    weights_ = [w for w in weights]
    for x in range(1, len(weights_)):
        weights_[x] += weights_[x - 1]
    return weights_


def partition_uniform(num_items, num_parts):
    """Split ``num_items`` into ``num_parts`` contiguous, near-equal ranges.

    Returns:
        list[int]: ``num_parts + 1`` boundaries; part ``p`` covers
        ``[parts[p], parts[p + 1])``. When the split is uneven the first
        ``num_items % num_parts`` parts hold one extra item.
    """
    # First check for the trivial edge case
    if num_items <= num_parts:
        return [min(p, num_items) for p in range(num_parts + 1)]

    chunksize = num_items // num_parts
    residual = num_items - (chunksize * num_parts)
    # Boundary p sits after p full chunks plus one extra item for each of the
    # first ``residual`` parts that precede it.
    return [p * chunksize + min(p, residual) for p in range(num_parts + 1)]


def partition_balanced(weights, num_parts):
    """
    use dynamic programming solve `The Linear Partition Problem`.
    see https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM

    Returns ``num_parts + 1`` boundaries into ``weights``. The cost minimised
    for each prefix is (largest part sum - smallest part sum).
    """
    import numpy as np
    n = len(weights)
    m = num_parts

    if n <= m:
        return partition_uniform(n, m)

    dp_max = np.full((n + 1, m + 1), np.inf)
    dp_min = np.full((n + 1, m + 1), np.inf)
    dp_cost = np.full((n + 1, m + 1), np.inf)
    position = np.zeros((n + 1, m + 1), dtype=int)
    prefix_sum = np.zeros((n + 1))
    prefix_sum[1:] = np.cumsum(weights)

    dp_max[0, 0] = 0
    dp_cost[0, 0] = 0
    for i in range(1, n + 1):
        for j in range(1, min(i, m) + 1):
            for k in range(i):
                # Candidate: first k items in j-1 parts, items k..i-1 as part j.
                max_sum = max(dp_max[k, j - 1], prefix_sum[i] - prefix_sum[k])
                min_sum = min(dp_min[k, j - 1], prefix_sum[i] - prefix_sum[k])
                cost = max_sum - min_sum
                # ``>=`` makes the largest k win among equal-cost candidates.
                if dp_cost[i, j] >= cost:
                    dp_cost[i, j] = cost
                    dp_max[i, j] = max_sum
                    dp_min[i, j] = min_sum
                    position[i, j] = k

    # Walk the recorded split points backwards from the full prefix.
    parts = [n]
    for i in reversed(range(1, m + 1)):
        parts.append(position[parts[-1], i])
    parts.reverse()

    return parts


class PartitionedTensor:
    """The local slice of a tensor that is uniformly partitioned over a process group."""

    def __init__(self, tensor, group, partition_meta=None):
        """Keep this rank's slice of ``tensor``.

        Args:
            tensor: tensor to partition; it is flattened and the local slice is cloned.
            group: process group that defines the number of parts and this rank.
            partition_meta: accepted but not used by this constructor.
        """
        super().__init__()

        self.group = group
        self.num_parts = dist.get_world_size(group=self.group)
        self.rank = dist.get_rank(group=self.group)
        self.orig_size = list(tensor.size())
        self.orig_device = tensor.device
        self.local_data, self.partition = self._partition_tensor(tensor)
        self.even_split = tensor.numel() % self.num_parts == 0

    @classmethod
    def from_meta(cls, meta, local_part, group, device=None):
        """Rebuild a ``PartitionedTensor`` from ``to_meta()`` output and a local slice.

        Args:
            meta: ``torch.long`` tensor produced by ``to_meta()``.
            local_part: this rank's slice of the data.
            group: process group the tensor is partitioned over.
            device: device recorded as ``orig_device``. ``None`` (the default)
                resolves to ``get_accelerator().device_name()`` at call time.
        """
        # Resolved lazily: a call in the signature would run once at import time.
        if device is None:
            device = get_accelerator().device_name()

        assert meta.dtype == torch.long
        dummy = torch.ones(dist.get_world_size(group=group))
        part_obj = cls(tensor=dummy, group=group)

        meta = meta.tolist()

        # [N, list0, ..., listN-1]
        part_obj.orig_size = meta[1:(1 + meta[0])]
        meta = meta[1 + meta[0]:]

        part_obj.orig_device = device
        part_obj.local_data = local_part.detach()

        part_obj.group = group

        # Partition is encoded like the rowptr of a CSR matrix:
        # [num_parts, rank, 0, part_1, ..., part_num_parts]
        # TODO: support shuffle between different partition granularities
        assert part_obj.num_parts == meta[0]
        assert part_obj.rank == meta[1]
        part_obj.partition = meta[2:]  # length num_parts+1

        # ``even_split`` was derived from the dummy tensor above (always evenly
        # divisible). Recompute it from the real partition, whose last boundary
        # is the total element count, so ``full()`` picks the right gather path.
        part_obj.even_split = part_obj.partition[-1] % part_obj.num_parts == 0

        return part_obj

    def _partition_tensor(self, tensor):
        """Return ``(local_slice, boundaries)`` for this rank using ``partition_uniform``."""
        partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
        start = partition[self.rank]
        length = partition[self.rank + 1] - start
        tensor_part = tensor.detach().contiguous().view(-1).narrow(0, start=start, length=length).clone()

        return tensor_part, partition

    def full(self, device=None):
        """Gather all parts and return the tensor in its original size.

        Args:
            device: device for the gathered tensor; defaults to ``orig_device``.
        """
        if device is None:
            device = self.orig_device

        # Allocate the full tensor as a flat buffer.
        full_numel = prod(self.full_size())
        flat_tensor = torch.zeros([full_numel], dtype=self.local_data.dtype, device=device)
        if self.even_split:
            # Collect the full tensor
            dist.all_gather_into_tensor(flat_tensor, self.local_data, group=self.group)
        else:
            # Uneven parts: broadcast each slice into its place, one part at a time.
            for part_id in range(self.num_parts):
                part_size = self.partition[part_id + 1] - self.partition[part_id]
                buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
                if part_id == self.rank:
                    buf.copy_(self.local_data)
                dist.broadcast(buf, part_id, self.group)
        return flat_tensor.view(self.full_size()).clone().detach()

    def to_meta(self):
        """Returns a torch.LongTensor that encodes partitioning information.

        Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
        communication.

        Returns:
            torch.LongTensor: a tensor encoding the meta-information for the partitioning
        """
        meta = []
        meta.append(len(self.orig_size))
        meta += list(self.orig_size)
        meta.append(self.num_parts)
        meta.append(self.rank)
        meta += self.partition
        return torch.LongTensor(data=meta).to(self.orig_device)

    def data(self):
        """Return this rank's slice."""
        return self.local_data

    def local_size(self):
        """Return the size of this rank's slice."""
        return self.local_data.size()

    def full_size(self):
        """Return the size of the original, unpartitioned tensor."""
        return self.orig_size


# Values seen by the previous ``memory_status`` call; used to print deltas.
mem_alloced = 0
mem_cached = 0


def memory_status(msg, print_rank=-1, reset_max=False):
    """Print current/delta/max allocated and cached accelerator memory.

    Args:
        msg: text included in the printed line.
        print_rank: only this rank prints; ``-1`` lets every rank print.
        reset_max: reset the accelerator's max counters before reading.
    """
    global mem_alloced, mem_cached

    rank = dist.get_rank()
    if print_rank != -1 and rank != print_rank:
        return

    get_accelerator().synchronize()

    if reset_max:
        get_accelerator().reset_max_memory_cached()
        get_accelerator().reset_max_memory_allocated()

    new_alloced = get_accelerator().memory_allocated()
    new_cached = get_accelerator().memory_cached()

    delta_alloced = new_alloced - mem_alloced
    delta_cached = new_cached - mem_cached

    mem_cached = new_cached
    mem_alloced = new_alloced

    max_alloced = get_accelerator().max_memory_allocated()
    max_cached = get_accelerator().max_memory_cached()

    # convert to GB for printing
    new_alloced /= 1024**3
    new_cached /= 1024**3
    delta_alloced /= 1024**3
    delta_cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3

    print(
        f'RANK={rank} MEMSTATS', msg, f'device={get_accelerator().current_device_name()} '
        f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
        f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')


def get_ma_status():
    """Return the accelerator's allocated memory; 0 on non-zero ranks of an initialized job."""
    if dist.is_initialized() and not dist.get_rank() == 0:
        return 0
    return get_accelerator().memory_allocated()


def empty_cache():
    """Empty the accelerator cache and reset its peak-memory statistics."""
    get_accelerator().empty_cache()
    get_accelerator().reset_peak_memory_stats()


def see_memory_usage(message, force=False):
    """Log accelerator and host memory usage; a no-op unless ``force`` is true.

    When the process group is initialized, only rank 0 logs. The accelerator's
    peak-memory statistics are reset at the end of each report.
    """
    if not force:
        return
    if dist.is_initialized() and not dist.get_rank() == 0:
        return

    # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
    gc.collect()

    # Print message except when distributed but not rank 0
    logger.info(message)
    gib = 1024 * 1024 * 1024
    # Adjacent literals (instead of backslash continuations inside one string)
    # keep source indentation out of the log line; every field uses 2 decimals.
    logger.info(f"MA {round(get_accelerator().memory_allocated() / gib, 2)} GB "
                f"Max_MA {round(get_accelerator().max_memory_allocated() / gib, 2)} GB "
                f"CA {round(torch_memory_reserved() / gib, 2)} GB "
                f"Max_CA {round(torch_max_memory_reserved() / gib, 2)} GB ")

    vm_stats = psutil.virtual_memory()
    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
    logger.info(f'CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%')

    # get the peak memory to report correct data, so reset the counter for the next call
    get_accelerator().reset_peak_memory_stats()


def call_to_str(base, *args, **kwargs):
    """Construct a string representation of a call.

    Args:
        base (str): name of the call
        args (tuple, optional): args to ``base``
        kwargs (dict, optional): kwargs supplied to ``base``

    Returns:
        str: A string representation of base(*args, **kwargs)
    """
    name = f'{base}('
    if args:
        name += ', '.join(repr(arg) for arg in args)
        if kwargs:
            name += ', '
    if kwargs:
        name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
    name += ')'
    return name


def get_only_unique_item(items):
    """Return the single distinct element of ``items``.

    Raises:
        RuntimeError: if ``items`` does not contain exactly one distinct element.
    """
    item_set = set(items)
    if len(item_set) != 1:
        raise RuntimeError(f"expected there to be only one unique element in {items}")
    unique_item, = item_set

    return unique_item
def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False):
    """Get norm of an iterable of tensors.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters.
    Taken from Nvidia Megatron.

    Arguments:
        input_tensors (Iterable[Tensor]): an iterable of Tensors will have norm computed
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        mpu (optional): model parallel unit; when given, the partial result is
            reduced over ``mpu.get_model_parallel_group()``.
        use_graph (bool): for finite norms, accumulate through ``graph_process``
            using a buffer list cached in ``graph_cache``.

    Returns:
        Total norm of the tensors (viewed as a single vector).
        -1 if the norm value is NaN or Inf.
    """
    assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}'
    # NOTE(review): this check iterates ``input_tensors`` once; a one-shot
    # iterator would be exhausted before the norm is computed - confirm callers pass lists.
    assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'

    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(t.data.abs().max() for t in input_tensors)
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        if use_graph:
            # The buffer list is created once and reused on later calls.
            # NOTE(review): it is sized from the first call's ``input_tensors`` -
            # presumably the tensor list never changes length; verify.
            if 'norm_tensors_compute_buffer' not in graph_cache:
                graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors]
            compute_buffer = graph_cache['norm_tensors_compute_buffer']

            def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
                # Write norm**p of every tensor into its slot and accumulate the
                # running sum in slot 0.
                for i, t in enumerate(tensor_list):
                    _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
                    if i != 0:
                        _compute_buffer[0].data.add_(_compute_buffer[i].data)

            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)

            total_norm = compute_buffer[0]
        else:
            total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])

        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
        if mpu is not None:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # -1 is the marker returned for a non-finite norm.
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1

    return total_norm


def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False):
    """Clip list of tensors by global norm.
    Args:
        input_tensors: List of tensors to be clipped
        max_norm (float, optional): norm above which the tensors are scaled down. Defaults to 1.0.
        global_norm (float, optional): Precomputed norm. Defaults to None.
        mpu (optional): model parallelism unit. Defaults to None.
        eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
        use_graph (bool, optional): apply the scaling through ``graph_process``. Defaults to False.
    Returns:
        float: the global norm
    """
    if global_norm is None:
        global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph)
    # NOTE(review): ``get_global_norm_of_tensors`` returns -1 for a non-finite
    # norm, which makes ``clip_coef`` negative and takes the scaling branch below -
    # confirm callers skip the step on overflow.
    clip_coef = max_norm / (global_norm + eps)
    if clip_coef < 1:
        if use_graph:

            def clip_tensors(_tensor_list, _clip_coef_tensor):
                for t in _tensor_list:
                    t.detach().mul_(_clip_coef_tensor)

            if 'clip_coef_tensor' not in graph_cache:
                # Alloc memory
                graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef,
                                                               dtype=torch.float32).to(get_accelerator().device_name())
            # The cached tensor is refreshed in place with the current coefficient.
            clip_coef_tensor = graph_cache['clip_coef_tensor']
            clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32))
            graph_process(False, clip_tensors, input_tensors, clip_coef_tensor)

        else:
            for t in input_tensors:
                t.detach().mul_(clip_coef)
    return global_norm


def align_dense_tensors(tensor_list, alignment):
    """Return ``tensor_list`` with a zero tensor appended so the total element
    count is a multiple of ``alignment``.

    When the count is already aligned, the input list object itself is returned.
    The padding tensor takes its device and dtype from ``tensor_list[0]``.
    """
    num_elements = sum(t.numel() for t in tensor_list)
    remaining = num_elements % alignment

    if remaining:
        elements_to_add = alignment - remaining
        pad_tensor = torch.zeros(elements_to_add, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
        padded_tensor_list = tensor_list + [pad_tensor]
    else:
        padded_tensor_list = tensor_list

    return padded_tensor_list


def all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, dp_process_group=None):
    """All-gather each group's local partition into its flat buffer.

    For every group the partition index is this process's rank in
    ``zp_process_group[group_id]``, while the collective runs over
    ``dp_process_group``. Nothing is gathered when that group has size 1.
    """
    for group_id, (group_flat, partitioned_params) in enumerate(zip(groups_flat, partitioned_param_groups)):
        partition_id = dist.get_rank(group=zp_process_group[group_id])
        dp_world_size = dist.get_world_size(group=dp_process_group)
        if dp_world_size == 1:
            # no groups share optimizer states
            # pipeline parallel with bf16 will default call this even if dp size = 1.
            continue
        # print("call contiguous for all_gather_into_tensor_dp_groups")
        dist.all_gather_into_tensor(group_flat, partitioned_params[partition_id].contiguous(), dp_process_group)


def all_gather_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, start_alignment_factor,
                         allgather_bucket_size, dp_process_group=None):
    """Delegate unconditionally to ``all_gather_into_tensor_dp_groups``.

    ``start_alignment_factor`` and ``allgather_bucket_size`` are accepted but
    unused: the bucketed fallback that consumed them is commented out below.
    """
    # if dist.has_all_gather_into_tensor():
    return all_gather_into_tensor_dp_groups(groups_flat, partitioned_param_groups, zp_process_group, dp_process_group)

    # for group_id, partitioned_params in enumerate(partitioned_param_groups):
    #     # Sequential AllGather Best of both worlds
    #     partition_id = dist.get_rank(group=dp_process_group[group_id])
    #     dp_world_size = dist.get_world_size(group=dp_process_group[group_id])
    #
    #     if dp_world_size == 1:
    #         # no groups share optimizer states
    #         # pipeline parallel with bf16 will default call this even if dp size = 1.
    #         continue
    #     num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)
    #
    #     shard_size = partitioned_params[partition_id].numel() // num_shards
    #
    #     # Enforce nccl/rccl alignment of start location of each shard
    #     shard_size = shard_size - (shard_size % start_alignment_factor)
    #
    #     num_elements = shard_size
    #
    #     assert shard_size * num_shards <= partitioned_params[partition_id].numel()
    #
    #     for shard_id in range(num_shards):
    #
    #         if shard_id == (num_shards - 1):
    #             num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size
    #
    #         shard_list = []
    #         for dp_id in range(dp_world_size):
    #             curr_shard = partitioned_params[dp_id].narrow(0, shard_id * shard_size, num_elements).detach()
    #             shard_list.append(curr_shard)
    #         dist.all_gather(shard_list, shard_list[partition_id].contiguous(), dp_process_group[group_id])


class TLinear(torch.nn.Linear):
    """``torch.nn.Linear`` built from an existing layer with its weight passed through ``transpose``.

    The constructor swaps ``orig_layer.weight``'s two dimensions when sizing the
    new layer and reuses ``orig_layer.bias`` as its own bias.
    """

    def __init__(self, orig_layer, name=""):
        self.name = name
        super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None))
        self.weight.data = transpose(orig_layer.weight.data)
        self.bias = orig_layer.bias
        # Pick the forward implementation once, based on whether a bias exists.
        self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd

    def _fwd(self, input):
        # Linear transform without a bias term.
        return F.linear(input, self.weight)

    def _fwd_bias_add(self, input):
        # Linear transform with the bias added.
        return F.linear(input, self.weight, bias=self.bias)

    def forward(self, input):
        return self._fwd_func(input)


def get_inactive_params(param_list):
    """Return the entries of ``param_list`` that have a ``ds_id`` attribute and
    whose ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``."""
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    return [param for param in param_list if (hasattr(param, 'ds_id') and \
                            param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
class ZPManager(object):
    """Tracks the "zp" process group that the current rank belongs to.

    Ranks ``0 .. world_size - 1`` are split into consecutive blocks of
    ``zp_size`` ranks; each block gets its own ``torch.distributed`` group.

    Attributes set by ``__init__``:
        rank / world_size: read from the ``RANK`` / ``WORLD_SIZE`` environment
            variables (defaulting to ``0`` / ``1``).
        zp_size: number of ranks per group.
        zp_group / zp_rank: ``None`` until ``init_group`` assigns them.
        is_initialized: ``True`` once ``init_group`` has completed.
    """

    def __init__(self, zp_size=8):
        self.rank = int(os.getenv('RANK', '0'))
        self.world_size = int(os.getenv("WORLD_SIZE", '1'))
        self.zp_size = zp_size
        self.zp_group = None
        self.zp_rank = None
        self.is_initialized = False

    def init_group(self):
        """Create the zp process groups and remember the one containing this rank.

        Calling this again after a successful run is a no-op.  When
        ``world_size // zp_size`` is zero no group is created and
        ``zp_group`` / ``zp_rank`` stay ``None``.
        """
        if self.is_initialized:
            return

        num_zp_groups: int = self.world_size // self.zp_size
        for i in range(num_zp_groups):
            ranks = range(i * self.zp_size, (i + 1) * self.zp_size)
            # new_group is invoked for every block, not only the caller's own.
            group = dist.new_group(ranks)
            if self.rank in ranks:
                self.zp_group = group
                self.zp_rank = self.rank % self.zp_size

        # Flag success only after every group was created, so that a failed
        # attempt does not turn all later calls into silent no-ops.
        self.is_initialized = True


zp_manager = ZPManager()
AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir) tokenizer_2 = None if args.text_encoder_name_2 is not None: tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir) if args.dataset == 't2v': transform = transforms.Compose([ ToTensorVideo(), *resize, norm_fun ]) # also work for img, because img is video when frame=1 return T2V_dataset( args, transform=transform, temporal_sample=temporal_sample, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2 ) elif args.dataset == 'i2v' or args.dataset == 'inpaint': resize_transform = Compose(resize) transform = Compose([ ToTensorAfterResize(), norm_fun, ]) return Inpaint_dataset( args, resize_transform=resize_transform, transform=transform, temporal_sample=temporal_sample, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2 ) raise NotImplementedError(args.dataset) if __name__ == "__main__": ''' python opensora/dataset/__init__.py ''' from accelerate import Accelerator from opensora.dataset.t2v_datasets import dataset_prog from opensora.utils.dataset_utils import LengthGroupedSampler, Collate from torch.utils.data import DataLoader import random from torch import distributed as dist from tqdm import tqdm args = type('args', (), { 'ae': 'WFVAEModel_D32_4x8x8', 'dataset': 't2v', 'model_max_length': 512, 'max_height': 640, 'max_width': 640, 'hw_stride': 16, 'num_frames': 93, 'compress_kv_factor': 1, 'interpolation_scale_t': 1, 'interpolation_scale_h': 1, 'interpolation_scale_w': 1, 'cache_dir': '../cache_dir', 'data': '/home/image_data/gyy/mmdit/Open-Sora-Plan/scripts/train_data/current_hq_on_npu.txt', 'train_fps': 18, 'drop_short_ratio': 0.0, 'speed_factor': 1.0, 'cfg': 0.1, 'text_encoder_name_1': 'google/mt5-xxl', 'text_encoder_name_2': None, 'dataloader_num_workers': 8, 'force_resolution': False, 'use_decord': True, 'group_data': True, 'train_batch_size': 1, 'gradient_accumulation_steps': 1, 'ae_stride': 8, 'ae_stride_t': 4, 'patch_size': 2, 'patch_size_t': 1, 
'total_batch_size': 256, 'sp_size': 1, 'max_hxw': 384*384, 'min_hxw': 384*288, # 'max_hxw': 236544, # 'min_hxw': 102400, } ) # accelerator = Accelerator() dataset = getdataset(args) # data = next(iter(dataset)) # import ipdb;ipdb.set_trace() # print() sampler = LengthGroupedSampler( args.train_batch_size, world_size=1, gradient_accumulation_size=args.gradient_accumulation_steps, initial_global_step=0, lengths=dataset.lengths, group_data=args.group_data, ) train_dataloader = DataLoader( dataset, shuffle=False, # pin_memory=True, collate_fn=Collate(args), batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, sampler=sampler, drop_last=False, prefetch_factor=4 ) import ipdb;ipdb.set_trace() import imageio import numpy as np from einops import rearrange while True: for idx, i in enumerate(tqdm(train_dataloader)): pixel_values = i[0][0] pixel_values_ = (pixel_values+1)/2 pixel_values_ = rearrange(pixel_values_, 'c t h w -> t h w c') * 255.0 pixel_values_ = pixel_values_.numpy().astype(np.uint8) imageio.mimwrite(f'output{idx}.mp4', pixel_values_, fps=args.train_fps) dist.barrier() pass ================================================ FILE: opensora/dataset/inpaint_dataset.py ================================================ import time import traceback try: import torch_npu from opensora.npu_config import npu_config except: torch_npu = None npu_config = None import glob import json import pickle import os, io, csv, math, random import numpy as np import torchvision from einops import rearrange from os.path import join as opj from collections import Counter import cv2 import pandas as pd import time import torch import torchvision.transforms as transforms from torch.utils.data.dataset import Dataset from torch.utils.data import DataLoader, Dataset, get_worker_info from tqdm import tqdm from PIL import Image from accelerate.logging import get_logger import gc import decord from opensora.utils.dataset_utils import DecordInit from opensora.utils.utils 
def type_ratio_normalize(mask_type_ratio_dict):
    """Rescale the ratios of a ``{mask_type: ratio}`` mapping so they sum to 1.

    Every ratio must be non-negative (checked with ``assert``).  When all
    ratios are zero, each key receives the same share ``1 / len(mapping)``.
    A new dict is returned; the input is left untouched.
    """
    for mask_type, ratio in mask_type_ratio_dict.items():
        assert ratio >= 0, f"mask_type_ratio_dict[{mask_type}] should be non-negative, but got {ratio}"

    ratio_sum = sum(mask_type_ratio_dict.values())
    n_types = len(mask_type_ratio_dict)
    if ratio_sum != 0:
        return {mask_type: ratio / ratio_sum for mask_type, ratio in mask_type_ratio_dict.items()}
    # All-zero input: fall back to a uniform distribution over the keys.
    return {mask_type: 1.0 / n_types for mask_type in mask_type_ratio_dict.keys()}
def drop(self, text, is_video=True):
    """Randomly replace the caption for classifier-free-guidance style training.

    With probability ``self.cfg`` the caption is dropped; a dropped caption
    becomes a fixed generic sentence with probability
    ``self.default_text_ratio`` and the empty string otherwise.  In all other
    cases ``text`` is passed through unchanged.

    Returns:
        ``{'text': <resulting caption>}``
    """
    # Both random numbers are drawn up front, whichever branch ends up taken.
    drop_draw = random.random()
    default_draw = random.random()

    if not drop_draw < self.cfg:
        return dict(text=text)
    if not default_draw < self.default_text_ratio:
        return dict(text='')

    subject = "video" if is_video else "image"
    return dict(text=f"The {subject} showcases a scene with coherent and clear visuals.")
sample_h = video_data['resolution']['sample_height'] sample_w = video_data['resolution']['sample_width'] if self.video_reader == 'decord': video = self.decord_read(video_data) elif self.video_reader == 'opencv': video = self.opencv_read(video_data) else: NotImplementedError(f'Found {self.video_reader}, but support decord or opencv') # import ipdb;ipdb.set_trace() video = self.resize_transform(video) # T C H W -> T C H W assert video.shape[2] == sample_h and video.shape[3] == sample_w, f'sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape}), video_path ({video_path})' inpaint_cond_data = self.mask_processor(video, mask_type_ratio_dict=self.mask_type_ratio_dict_video) mask, masked_video = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values'] video = self.transform(video) # T C H W -> T C H W masked_video = self.transform(masked_video) # T C H W -> T C H W video = torch.cat([video, masked_video, mask], dim=1) # T 2C+1 H W video = video.transpose(0, 1) # T C H W -> C T H W text = video_data['cap'] if not isinstance(text, list): text = [text] text = [random.choice(text)] if video_data.get('aesthetic', None) is not None or video_data.get('aes', None) is not None: aes = video_data.get('aesthetic', None) or video_data.get('aes', None) text = [add_aesthetic_notice_video(text[0], aes)] text = self.drop(text, is_video=True)['text'] text_tokens_and_mask_1 = self.tokenizer_1( text, max_length=self.model_max_length, padding='max_length', truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors='pt' ) input_ids_1 = text_tokens_and_mask_1['input_ids'] cond_mask_1 = text_tokens_and_mask_1['attention_mask'] input_ids_2, cond_mask_2 = None, None if self.tokenizer_2 is not None: text_tokens_and_mask_2 = self.tokenizer_2( text, max_length=self.tokenizer_2.model_max_length, padding='max_length', truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors='pt' ) input_ids_2 = 
def get_image(self, idx):
    """Load one image sample together with its inpainting condition and tokenized caption.

    The image is resized with ``self.resize_transform``, a mask and masked
    copy are produced by ``self.mask_processor``, and image / masked image are
    passed through ``self.transform`` before being concatenated with the mask
    along dim 1 and transposed to put that dim first.

    Returns:
        dict with ``pixel_values``, ``input_ids_1``, ``cond_mask_1``,
        ``motion_score`` (always ``None``), ``input_ids_2`` and ``cond_mask_2``
        (the last two are ``None`` when no second tokenizer is configured).

    Raises:
        AssertionError: if the resized image does not have the pre-computed
            ``sample_height`` x ``sample_width``.
    """
    image_data = dataset_prog.cap_list[idx]  # [{'path': path, 'cap': cap}, ...]
    sample_h = image_data['resolution']['sample_height']
    sample_w = image_data['resolution']['sample_width']

    # convert() returns a new, fully loaded image, so the file handle can be
    # closed right away instead of being left to the garbage collector.
    with Image.open(image_data['path']) as image_file:
        image = image_file.convert('RGB')  # [h, w, c]
    image = torch.from_numpy(np.array(image))  # [h, w, c]
    image = rearrange(image, 'h w c -> c h w').unsqueeze(0)  # [1 c h w]

    image = self.resize_transform(image)  # [1 c h w]
    # Both dimensions must be part of the condition: in ``assert a, b`` the
    # second expression is only the failure message and is never checked.
    assert image.shape[2] == sample_h and image.shape[3] == sample_w, \
        f"image_data: {image_data}, but found image {image.shape}"

    inpaint_cond_data = self.mask_processor(image, mask_type_ratio_dict=self.mask_type_ratio_dict_image)
    mask, masked_image = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values']

    image = self.transform(image)
    masked_image = self.transform(masked_image)

    image = torch.cat([image, masked_image, mask], dim=1)  # [1 2C+1 H W]
    image = image.transpose(0, 1)  # [1 C H W] -> [C 1 H W]

    caps = image_data['cap'] if isinstance(image_data['cap'], list) else [image_data['cap']]
    caps = [random.choice(caps)]
    if '/sam/' in image_data['path']:
        caps = [add_masking_notice(caps[0])]
    if image_data.get('aesthetic', None) is not None or image_data.get('aes', None) is not None:
        aes = image_data.get('aesthetic', None) or image_data.get('aes', None)
        caps = [add_aesthetic_notice_image(caps[0], aes)]
    text = text_preprocessing(caps, support_Chinese=self.support_Chinese)
    text = self.drop(text, is_video=False)['text']

    text_tokens_and_mask_1 = self.tokenizer_1(
        text,
        max_length=self.model_max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    input_ids_1 = text_tokens_and_mask_1['input_ids']  # 1, l
    cond_mask_1 = text_tokens_and_mask_1['attention_mask']  # 1, l

    input_ids_2, cond_mask_2 = None, None
    if self.tokenizer_2 is not None:
        text_tokens_and_mask_2 = self.tokenizer_2(
            text,
            max_length=self.tokenizer_2.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids_2 = text_tokens_and_mask_2['input_ids']  # 1, l
        cond_mask_2 = text_tokens_and_mask_2['attention_mask']  # 1, l

    return dict(
        pixel_values=image,
        input_ids_1=input_ids_1, cond_mask_1=cond_mask_1,
        motion_score=None,
        input_ids_2=input_ids_2, cond_mask_2=cond_mask_2
    )
class SingletonMeta(type):
    """Metaclass that turns every class created with it into a singleton.

    The first call to such a class constructs the instance; every later call
    returns that same instance and ignores the arguments it was given.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = super().__call__(*args, **kwargs)
        registry[cls] = created
        return registry[cls]
def filter_resolution(h, w, max_h_div_w_ratio=17/16, min_h_div_w_ratio=8 / 16):
    """Tell whether the aspect ratio ``h / w`` lies inside the accepted range.

    Args:
        h: height.
        w: width.
        max_h_div_w_ratio: largest accepted ``h / w`` (inclusive).
        min_h_div_w_ratio: smallest accepted ``h / w`` (inclusive).

    Returns:
        ``True`` when ``min_h_div_w_ratio <= h / w <= max_h_div_w_ratio``,
        ``False`` otherwise.  A zero width has no aspect ratio and is
        rejected instead of raising ``ZeroDivisionError``.
    """
    if w == 0:
        return False
    aspect = h / w
    return bool(min_h_div_w_ratio <= aspect <= max_h_div_w_ratio)
def get_video(self, idx):
    """Load one video sample and its tokenized caption.

    The frames are read with the backend named by ``self.video_reader``,
    passed through ``self.transform`` and transposed from ``T C H W`` to
    ``C T H W``.  One caption is picked at random, optionally extended with
    an aesthetic notice, and replaced by ``""`` with probability ``self.cfg``.

    Returns:
        dict with ``pixel_values``, ``input_ids_1``, ``cond_mask_1``,
        ``input_ids_2`` and ``cond_mask_2`` (the last two are ``None`` when no
        second tokenizer is configured).

    Raises:
        AssertionError: if the file is missing or the transformed video does
            not have the pre-computed sample height / width.
        NotImplementedError: if ``self.video_reader`` is neither ``'decord'``
            nor ``'opencv'``.
    """
    video_data = dataset_prog.cap_list[idx]
    video_path = video_data['path']
    assert os.path.exists(video_path), f"file {video_path} do not exist!"
    sample_h = video_data['resolution']['sample_height']
    sample_w = video_data['resolution']['sample_width']
    if self.video_reader == 'decord':
        video = self.decord_read(video_data)
    elif self.video_reader == 'opencv':
        video = self.opencv_read(video_data)
    else:
        # The exception used to be built but never raised, which surfaced
        # later as an unrelated UnboundLocalError on ``video``.
        raise NotImplementedError(f'Found {self.video_reader}, but support decord or opencv')

    video = self.transform(video)  # T C H W -> T C H W
    assert video.shape[2] == sample_h and video.shape[3] == sample_w, f'sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape})'

    video = video.transpose(0, 1)  # T C H W -> C T H W
    text = video_data['cap']
    if not isinstance(text, list):
        text = [text]
    text = [random.choice(text)]
    if video_data.get('aesthetic', None) is not None or video_data.get('aes', None) is not None:
        aes = video_data.get('aesthetic', None) or video_data.get('aes', None)
        text = [add_aesthetic_notice_video(text[0], aes)]
    text = text_preprocessing(text, support_Chinese=self.support_Chinese)
    text = text if random.random() > self.cfg else ""

    text_tokens_and_mask_1 = self.tokenizer_1(
        text,
        max_length=self.model_max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    input_ids_1 = text_tokens_and_mask_1['input_ids']
    cond_mask_1 = text_tokens_and_mask_1['attention_mask']

    input_ids_2, cond_mask_2 = None, None
    if self.tokenizer_2 is not None:
        text_tokens_and_mask_2 = self.tokenizer_2(
            text,
            max_length=self.tokenizer_2.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids_2 = text_tokens_and_mask_2['input_ids']
        cond_mask_2 = text_tokens_and_mask_2['attention_mask']

    return dict(
        pixel_values=video,
        input_ids_1=input_ids_1, cond_mask_1=cond_mask_1,
        input_ids_2=input_ids_2, cond_mask_2=cond_mask_2,
    )
def define_frame_index(self, data):
    """Scan the annotation files and build the filtered list of training samples.

    ``data`` is a text file; every non-empty line is ``<root>,<annotation>``
    where ``<annotation>`` is a ``.json`` or ``.pkl`` file holding a list of
    item dicts.  Each item's ``path`` is joined onto ``<root>``.  Items are
    discarded when they have no caption, no usable resolution, a resolution /
    aspect ratio outside the configured limits, or (videos) a frame count that
    is too long or too short.  Kept items are annotated in place with
    ``resolution['sample_height'|'sample_width']``, ``sample_frame_index`` and,
    for videos, ``start_frame_idx`` and ``fps``.  Finally, shapes that occur
    fewer than ``4 * self.total_batch_size`` times are removed.

    Returns:
        ``(new_cap_list, sample_size, shape_idx_dict)``: the kept items, their
        ``"<frames>x<height>x<width>"`` shape keys (both tuples, index-aligned),
        and a dict mapping each shape key to the indices that have it.

    Raises:
        ValueError: for an annotation file that is neither ``.json`` nor
            ``.pkl``, or when no sample survives the filters.
        NameError: for a media path that is neither ``.mp4`` nor ``.jpg``.
    """
    shape_idx_dict = {}
    new_cap_list = []
    sample_size = []
    aesthetic_score = []
    cnt_vid = 0
    cnt_img = 0
    cnt_too_long = 0
    cnt_too_short = 0
    cnt_no_cap = 0
    cnt_no_resolution = 0
    cnt_no_aesthetic = 0
    cnt_img_res_mismatch_stride = 0
    cnt_vid_res_mismatch_stride = 0
    cnt_img_aspect_mismatch = 0
    cnt_vid_aspect_mismatch = 0
    cnt_img_res_too_small = 0
    cnt_vid_res_too_small = 0
    cnt_vid_after_filter = 0
    cnt_img_after_filter = 0
    cnt = 0

    with open(data, 'r') as f:
        folder_anno = [line.strip().split(',') for line in f.readlines() if len(line.strip()) > 0]
    for sub_root, anno in tqdm(folder_anno):
        print(f'Building {anno}...')
        if anno.endswith('.json'):
            with open(anno, 'r') as f:
                sub_list = json.load(f)
        elif anno.endswith('.pkl'):
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at annotation files from a trusted source.
            with open(anno, "rb") as f:
                sub_list = pickle.load(f)
        else:
            # Previously fell through and reused the preceding file's list
            # (or hit an undefined name on the first entry).
            raise ValueError(f'Unsupported annotation file {anno}, only support .json and .pkl')
        for i in tqdm(sub_list):
            cnt += 1
            path = os.path.join(sub_root, i['path'])
            i['path'] = path

            if path.endswith('.mp4'):
                cnt_vid += 1
            elif path.endswith('.jpg'):
                cnt_img += 1

            # ======no aesthetic=====
            # An item lacks a score only when BOTH keys are missing; with
            # ``or`` every item carrying just one of them was miscounted and
            # its score left out of the statistics.
            if i.get('aesthetic', None) is None and i.get('aes', None) is None:
                cnt_no_aesthetic += 1
            else:
                aesthetic_score.append(i['aesthetic'] if i.get('aesthetic', None) is not None else i['aes'])

            # ======no caption=====
            cap = i.get('cap', None)
            if cap is None:
                cnt_no_cap += 1
                continue

            # ======resolution mismatch=====
            if i.get('resolution', None) is None:
                cnt_no_resolution += 1
                continue
            if i['resolution'].get('height', None) is None or i['resolution'].get('width', None) is None:
                cnt_no_resolution += 1
                continue
            height, width = i['resolution']['height'], i['resolution']['width']
            if not self.force_resolution:
                if height <= 0 or width <= 0:
                    cnt_no_resolution += 1
                    continue

                tr_h, tr_w = maxhwresize(height, width, self.max_hxw)
                _, _, sample_h, sample_w = get_params(tr_h, tr_w, self.hw_stride)

                if sample_h <= 0 or sample_w <= 0:
                    if path.endswith('.mp4'):
                        cnt_vid_res_mismatch_stride += 1
                    elif path.endswith('.jpg'):
                        cnt_img_res_mismatch_stride += 1
                    continue

                # filter min_hxw
                if sample_h * sample_w < self.min_hxw:
                    if path.endswith('.mp4'):
                        cnt_vid_res_too_small += 1
                    elif path.endswith('.jpg'):
                        cnt_img_res_too_small += 1
                    continue

                # filter aspect
                is_pick = filter_resolution(
                    sample_h, sample_w, max_h_div_w_ratio=self.hw_aspect_thr, min_h_div_w_ratio=1/self.hw_aspect_thr
                )
                if not is_pick:
                    if path.endswith('.mp4'):
                        cnt_vid_aspect_mismatch += 1
                    elif path.endswith('.jpg'):
                        cnt_img_aspect_mismatch += 1
                    continue
                i['resolution'].update(dict(sample_height=sample_h, sample_width=sample_w))
            else:
                aspect = self.max_height / self.max_width
                is_pick = filter_resolution(
                    height, width, max_h_div_w_ratio=self.hw_aspect_thr*aspect, min_h_div_w_ratio=1/self.hw_aspect_thr*aspect
                )
                if not is_pick:
                    if path.endswith('.mp4'):
                        cnt_vid_aspect_mismatch += 1
                    elif path.endswith('.jpg'):
                        cnt_img_aspect_mismatch += 1
                    continue
                sample_h, sample_w = self.max_height, self.max_width
                i['resolution'].update(dict(sample_height=sample_h, sample_width=sample_w))

            if path.endswith('.mp4'):
                # Store the default on the item as well: the readers index
                # video_data['fps'] directly and would otherwise hit KeyError
                # for annotations without an fps field.
                fps = i.setdefault('fps', 24)
                # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
                if i['num_frames'] > self.too_long_factor * (self.num_frames * fps / self.train_fps * self.speed_factor):
                    # too long video is not suitable for this training stage (self.num_frames)
                    cnt_too_long += 1
                    continue

                # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24)
                frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps
                start_frame_idx = i.get('cut', [0])[0]
                i['start_frame_idx'] = start_frame_idx
                frame_indices = np.arange(start_frame_idx, start_frame_idx+i['num_frames'], frame_interval).astype(int)
                frame_indices = frame_indices[frame_indices < start_frame_idx+i['num_frames']]

                # comment out it to enable dynamic frames training
                if len(frame_indices) < self.num_frames and torch.rand(1, generator=self.generator).item() < self.drop_short_ratio:
                    cnt_too_short += 1
                    continue

                # too long video will be temporal-crop randomly
                if len(frame_indices) > self.num_frames:
                    begin_index, end_index = self.temporal_sample(len(frame_indices))
                    frame_indices = frame_indices[begin_index: end_index]
                # to find a suitable end_frame_idx, to ensure we do not need pad video
                end_frame_idx = find_closest_y(
                    len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size
                )
                if end_frame_idx == -1:  # too short that can not be encoded exactly by videovae
                    cnt_too_short += 1
                    continue
                frame_indices = frame_indices[:end_frame_idx]

                i['sample_frame_index'] = frame_indices.tolist()

                new_cap_list.append(i)
                cnt_vid_after_filter += 1

            elif path.endswith('.jpg'):  # image
                cnt_img_after_filter += 1
                i['sample_frame_index'] = [0]
                new_cap_list.append(i)

            else:
                raise NameError(f"Unknown file extention {path.split('.')[-1]}, only support .mp4 for video and .jpg for image")

            pre_define_shape = f"{len(i['sample_frame_index'])}x{sample_h}x{sample_w}"
            sample_size.append(pre_define_shape)

    counter = Counter(sample_size)
    if not self.force_resolution and self.max_hxw is not None and self.min_hxw is not None:
        assert all([np.prod(np.array(k.split('x')[1:]).astype(np.int32)) <= self.max_hxw for k in counter.keys()])
        assert all([np.prod(np.array(k.split('x')[1:]).astype(np.int32)) >= self.min_hxw for k in counter.keys()])

    len_before_filter_major = len(sample_size)
    filter_major_num = 4 * self.total_batch_size
    kept = [[item, shape] for item, shape in zip(new_cap_list, sample_size) if counter[shape] >= filter_major_num]
    if not kept:
        # zip(*[]) cannot be unpacked into two names; fail with an
        # explanation instead of "not enough values to unpack".
        raise ValueError(
            f'No sample left after filtering: {cnt} items read, {len_before_filter_major} passed the per-item '
            f'filters, but no shape occurs at least {filter_major_num} times (4 * total_batch_size)'
        )
    new_cap_list, sample_size = zip(*kept)
    for idx, shape in enumerate(sample_size):
        if shape_idx_dict.get(shape, None) is None:
            shape_idx_dict[shape] = [idx]
        else:
            shape_idx_dict[shape].append(idx)
    cnt_filter_minority = len_before_filter_major - len(sample_size)
    counter = Counter(sample_size)

    print(f'no_cap: {cnt_no_cap}, no_resolution: {cnt_no_resolution}\n'
          f'too_long: {cnt_too_long}, too_short: {cnt_too_short}\n'
          f'cnt_img_res_mismatch_stride: {cnt_img_res_mismatch_stride}, cnt_vid_res_mismatch_stride: {cnt_vid_res_mismatch_stride}\n'
          f'cnt_img_res_too_small: {cnt_img_res_too_small}, cnt_vid_res_too_small: {cnt_vid_res_too_small}\n'
          f'cnt_img_aspect_mismatch: {cnt_img_aspect_mismatch}, cnt_vid_aspect_mismatch: {cnt_vid_aspect_mismatch}\n'
          f'cnt_filter_minority: {cnt_filter_minority}\n'
          f'Counter(sample_size): {counter}\n'
          f'cnt_vid: {cnt_vid}, cnt_vid_after_filter: {cnt_vid_after_filter}, use_ratio: {round(cnt_vid_after_filter/(cnt_vid+1e-6), 5)*100}%\n'
          f'cnt_img: {cnt_img}, cnt_img_after_filter: {cnt_img_after_filter}, use_ratio: {round(cnt_img_after_filter/(cnt_img+1e-6), 5)*100}%\n'
          f'before filter: {cnt}, after filter: {len(new_cap_list)}, use_ratio: {round(len(new_cap_list)/cnt, 5)*100}%')

    if len(aesthetic_score) > 0:
        stats_aesthetic = calculate_statistics(aesthetic_score)
        print(f"before filter: {cnt}, after filter: {len(new_cap_list)}\n"
              f"aesthetic_score: {len(aesthetic_score)}, cnt_no_aesthetic: {cnt_no_aesthetic}\n"
              f"{len([i for i in aesthetic_score if i>=5.75])} > 5.75, 4.5 > {len([i for i in aesthetic_score if i<=4.5])}\n"
              f"Mean: {stats_aesthetic['mean']}, Var: {stats_aesthetic['variance']}, Std: {stats_aesthetic['std_dev']}\n"
              f"Min: {stats_aesthetic['min']}, Max: {stats_aesthetic['max']}")

    return new_cap_list, sample_size, shape_idx_dict
def opencv_read(self, video_data):
    """Decode the pre-selected frames of one clip with OpenCV.

    Args:
        video_data (dict): sample metadata. The keys read here are ``path``,
            ``sample_frame_index``, ``start_frame_idx``, ``num_frames``,
            ``fps`` and the optional ``crop`` box ``[s_x, e_x, s_y, e_y]``.

    Returns:
        torch.Tensor: frames stacked as (T, C, H, W) in RGB channel order,
        spatially cropped when a crop box is present.

    Raises:
        ValueError: if the file cannot be opened or a requested frame cannot
            be decoded.
    """
    path = video_data['path']
    predefine_frame_indice = video_data['sample_frame_index']
    start_frame_idx = video_data['start_frame_idx']
    clip_total_frames = video_data['num_frames']
    fps = video_data['fps']
    s_x, e_x, s_y, e_y = video_data.get('crop', [None, None, None, None])

    predefine_num_frames = len(predefine_frame_indice)
    cv2_vr = cv2.VideoCapture(path)
    # try/finally guarantees the capture handle is released even when frame
    # selection or decoding raises.
    try:
        if not cv2_vr.isOpened():
            raise ValueError(f'can not open {path}')
        frame_indices = self.get_actual_frame(
            fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice
        )

        frames = []
        for frame_idx in frame_indices:
            # Seek to the absolute frame position before decoding it.
            cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ok, frame = cv2_vr.read()
            if not ok or frame is None:
                # Fail with a readable message instead of crashing in cvtColor.
                raise ValueError(f'can not read frame {frame_idx} from {path}')
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)
    finally:
        cv2_vr.release()

    video_data = torch.stack(frames)  # (T C H W)
    if s_y is not None:
        video_data = video_data[:, :, s_y: e_y, s_x: e_x]
    return video_data
min(self.speed_factor, max_speed_factor)) speed_factor = min(self.speed_factor, max_speed_factor) target_frame_count = int(len(frame_indices) / speed_factor) speed_frame_idx = np.linspace(0, len(frame_indices) - 1, target_frame_count, dtype=int) frame_indices = frame_indices[speed_frame_idx] # too long video will be temporal-crop randomly if len(frame_indices) > self.num_frames: begin_index, end_index = self.temporal_sample(len(frame_indices)) frame_indices = frame_indices[begin_index: end_index] # frame_indices = frame_indices[:self.num_frames] # head crop # to find a suitable end_frame_idx, to ensure we do not need pad video end_frame_idx = find_closest_y( len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size ) if end_frame_idx == -1: # too short that can not be encoded exactly by videovae raise IndexError(f'video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})') frame_indices = frame_indices[:end_frame_idx] if predefine_num_frames != len(frame_indices): raise ValueError(f'video ({path}) predefine_num_frames ({predefine_num_frames}) ({predefine_frame_indice}) is not equal with frame_indices ({len(frame_indices)}) ({frame_indices})') if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1: raise IndexError(f'video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})') return frame_indices ================================================ FILE: opensora/dataset/transform.py ================================================ import torch import random import numbers from torchvision.transforms import RandomCrop, RandomResizedCrop import statistics import numpy as np import ftfy import regex as re import html def _is_tensor_video_clip(clip): if not torch.is_tensor(clip): raise TypeError("clip should be Tensor. Got %s" % type(clip)) if not clip.ndimension() == 4: raise ValueError("clip should be 4D. 
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly halves the image while its short side is at least twice
    ``image_size``, rescales so the short side equals ``image_size``, then cuts
    a centered ``image_size`` x ``image_size`` square.

    Args:
        pil_image: image object exposing ``.size`` and ``.resize`` (presumably a
            PIL image -- TODO confirm).
        image_size (int): side length of the returned square crop.

    Returns:
        The cropped image built with ``Image.fromarray``.

    NOTE(review): ``Image`` is not among this module's top-level imports
    (torch, random, numbers, torchvision.transforms, statistics, numpy, ftfy,
    regex, html). Presumably ``PIL.Image`` is intended; unless it is imported
    somewhere, calling this function raises NameError.
    """
    # Coarse stage: cheap 2x box-filter reductions while the image is far too large.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Fine stage: bicubic resize that brings the short side to image_size.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    # Crop the centered square on the array form (rows first, then columns).
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def center_crop_using_short_edge(clip):
    """Cut the centered square whose side equals the clip's shorter spatial edge.

    Args:
        clip (torch.tensor): 4D video clip; the last two dimensions are
            treated as height and width.

    Returns:
        torch.tensor: square crop taken along the longer edge.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    height, width = clip.shape[-2], clip.shape[-1]
    portrait_or_square = not (height < width)
    side = width if portrait_or_square else height

    # Only the longer axis needs an offset; the shorter one starts at 0.
    top = int(round((height - side) / 2.0)) if portrait_or_square else 0
    left = 0 if portrait_or_square else int(round((width - side) / 2.0))
    return crop(clip, top, left, side, side)
def normalize(clip, mean, std, inplace=False):
    """Subtract ``mean`` and divide by ``std`` along dimension 0 of a 4D clip.

    The statistics are reshaped to (N, 1, 1, 1) before the in-place arithmetic,
    so they broadcast against dimension 0 of ``clip``.

    NOTE(review): most helpers in this module describe clips as (T, C, H, W),
    while NormalizeVideo documents (C, T, H, W). With per-channel statistics
    the broadcast here only lines up with the channel axis in the latter
    layout -- confirm which layout callers pass.

    Args:
        clip (torch.tensor): 4D video clip to be normalized.
        mean (tuple): per-entry mean, one value per slice of dimension 0.
        std (tuple): per-entry standard deviation, same length as ``mean``.
        inplace (bool): when False (default) a clone is normalized and the
            input is left untouched.

    Returns:
        torch.tensor: the normalized clip (the input object when ``inplace``).
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")

    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)

    target.sub_(mean_t[:, None, None, None])
    target.div_(std_t[:, None, None, None])
    return target
class RandomCropVideo:
    """Crop a video clip at a uniformly random spatial location.

    Args:
        size (number or pair): a single number yields a square
            ``(int(size), int(size))`` crop; anything else is stored as given
            and unpacked as ``(height, width)``.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        top, left, out_h, out_w = self.get_params(clip)
        return crop(clip, top, left, out_h, out_w)

    def get_params(self, clip):
        """Pick the crop box ``(top, left, height, width)`` for ``clip``.

        Raises:
            ValueError: if the requested crop is larger than the clip.
        """
        src_h, src_w = clip.shape[-2:]
        out_h, out_w = self.size

        if src_h < out_h or src_w < out_w:
            raise ValueError(f"Required crop size {(out_h, out_w)} is larger than input image size {(src_h, src_w)}")

        # Exact fit: nothing to randomize, and no random numbers are drawn.
        if src_w == out_w and src_h == out_h:
            return 0, 0, src_h, src_w

        top = torch.randint(0, src_h - out_h + 1, size=(1,)).item()
        left = torch.randint(0, src_w - out_w + 1, size=(1,)).item()
        return top, left, out_h, out_w

    def __repr__(self) -> str:
        return f"{type(self).__name__}(size={self.size})"
def longsideresize(h, w, size, skip_low_resolution):
    """Compute the (height, width) that fits an h x w frame inside ``size``.

    The aspect ratio is kept (up to integer truncation): whichever side would
    overflow the ``size = (max_h, max_w)`` box first is pinned to the box, and
    the other side is scaled by the same factor and truncated with ``int``.

    Args:
        h (int): source height.
        w (int): source width.
        size (pair): target box as (height, width).
        skip_low_resolution (bool): when True, a frame that already fits
            inside the box is returned unchanged instead of being scaled.

    Returns:
        tuple: the new (height, width).
    """
    box_h, box_w = size[0], size[1]

    if skip_low_resolution and h <= box_h and w <= box_w:
        return h, w

    if h / w > box_h / box_w:
        # Taller than the box, e.g. 720x1280 into 320x640 -> 320x568.
        return box_h, int(box_h / h * w)

    # Wider than (or same ratio as) the box, e.g. 720x1280 into 480x640 -> 360x640.
    return int(box_w / w * h), box_w
class MaxHWResizeVideo:
    '''
    Resize a clip to the height/width computed by ``maxhwresize`` for the
    pixel budget ``max_hxw``; clips whose computed size equals their current
    size are returned untouched.

    Args:
        max_hxw (int): pixel budget (height * width) passed to ``maxhwresize``.
        interpolation_mode (str): mode forwarded to ``resize``.
    '''

    def __init__(
            self,
            max_hxw,
            interpolation_mode="bilinear",
    ):
        self.max_hxw = max_hxw
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized video clip.
        """
        _, _, h, w = clip.shape
        tr_h, tr_w = maxhwresize(h, w, self.max_hxw)
        if h == tr_h and w == tr_w:
            # Target size equals the current size: skip the interpolation.
            return clip
        resize_clip = resize(clip, target_size=(tr_h, tr_w),
                             interpolation_mode=self.interpolation_mode)
        return resize_clip

    def __repr__(self) -> str:
        # This class stores `max_hxw`, not `size`; reading self.size here used
        # to raise AttributeError.
        return f"{self.__class__.__name__}(max_hxw={self.max_hxw}, interpolation_mode={self.interpolation_mode})"
class UCFCenterCropVideo:
    '''
    First scale the clip (via ``resize_scale``) so that its short edge matches
    the requested size, then center crop it.

    Args:
        size (int or pair): an int gives a square (size, size) target; a tuple
            or list must hold exactly (height, width).
        interpolation_mode (str): mode forwarded to ``resize_scale``.

    Raises:
        ValueError: if a tuple/list ``size`` does not have two entries.
    '''

    def __init__(
            self,
            size,
            interpolation_mode="bilinear",
    ):
        # Lists are accepted like tuples; previously a list was wrapped into
        # ([h, w], [h, w]) and only failed later inside the resize.
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original repr.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class CenterCropVideo:
    """Center crop a clip to a fixed (height, width) without resizing.

    Args:
        size (int or pair): an int gives a square (size, size) crop; a tuple
            or list must hold exactly (height, width).
        interpolation_mode (str): stored and shown in the repr; ``__call__``
            does not use it.

    Raises:
        ValueError: if a tuple/list ``size`` does not have two entries.
    """

    def __init__(
            self,
            size,
            interpolation_mode="bilinear",
    ):
        # Lists are accepted like tuples; previously a list was wrapped into
        # ([h, w], [h, w]) and only failed later when unpacked by center_crop.
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original repr.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Pick a random window of at most ``size`` frames.

        Args:
            total_frames (int): number of frames available.

        Returns:
            tuple: ``(begin_index, end_index)`` with ``end_index`` exclusive.
            When ``total_frames <= size`` the whole range ``(0, total_frames)``
            is returned.
        """
        # random.randint includes both endpoints, so the last valid start is
        # total_frames - size. The previous bound (total_frames - size - 1)
        # meant the final window, and therefore the last frame, was never
        # sampled.
        last_begin = max(0, total_frames - self.size)
        begin_index = random.randint(0, last_begin)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index
# Candidate sentences that add_aesthetic_notice_video attaches to a caption
# when the aesthetic score is >= 5.75; one entry is chosen at random.
# NOTE(review): the fourth entry is a sentence fragment and the fifth has no
# trailing period, unlike the rest of the list. These strings end up in
# captions, so confirm whether that is intentional before editing them.
high_aesthetic_score_notices_video = [
    "This video has a high aesthetic quality.",
    "The beauty of this video is exceptional.",
    "This video scores high in aesthetic value.",
    "With its harmonious colors and balanced composition.",
    "This video ranks highly for aesthetic quality",
    "The artistic quality of this video is excellent.",
    "This video is rated high for beauty.",
    "The aesthetic quality of this video is impressive.",
    "This video has a top aesthetic score.",
    "The visual appeal of this video is outstanding.",
]
def add_masking_notice(caption):
    """Attach a random face-masking notice to captions that mention people.

    A caption counts as mentioning people when any entry of the module-level
    ``keywords`` list occurs in it as a substring. The notice is placed, at
    random, either after or before the caption, separated by one space.

    Args:
        caption (str): original caption.

    Returns:
        str: caption with a notice attached, or the caption unchanged when no
        keyword matches.
    """
    mentions_person = False
    for word in keywords:
        if word in caption:
            mentions_person = True
            break

    if not mentions_person:
        return caption

    notice = random.choice(masking_notices)
    appended = caption + ' ' + notice
    prepended = notice + ' ' + caption
    return random.choice([appended, prepended])
def add_aesthetic_notice_video(caption, aesthetic_score):
    """Attach a quality notice to a video caption based on its aesthetic score.

    Scores <= 4.25 draw from the low-quality notices, scores >= 5.75 from the
    high-quality notices; anything in between leaves the caption unchanged.
    The notice is placed, at random, either after or before the caption.

    Args:
        caption (str): original caption.
        aesthetic_score (float): aesthetic score of the video.

    Returns:
        str: caption, possibly with a notice attached.
    """
    if aesthetic_score <= 4.25:
        pool = low_aesthetic_score_notices_video
    elif aesthetic_score >= 5.75:
        pool = high_aesthetic_score_notices_video
    else:
        return caption

    notice = random.choice(pool)
    appended = caption + ' ' + notice
    prepended = notice + ' ' + caption
    return random.choice([appended, prepended])
def calculate_statistics(data):
    """Summarize a sequence of numbers.

    Args:
        data: sized collection of numeric values.

    Returns:
        dict or None: ``None`` for empty input; otherwise a dict with the keys
        ``mean``, ``variance``, ``std_dev``, ``min`` and ``max`` computed with
        numpy's default settings.
    """
    if len(data) == 0:
        return None

    values = np.array(data)
    reducers = (
        ('mean', np.mean),
        ('variance', np.var),
        ('std_dev', np.std),
        ('min', np.min),
        ('max', np.max),
    )
    return {label: reduce_fn(values) for label, reduce_fn in reducers}
class SuppressStdout:
    """Singleton context manager that redirects ``sys.stdout`` to os.devnull.

    Every construction returns the same shared instance. Because the instance
    is shared, entries are tracked on a stack so that nested ``with`` blocks
    restore the stream that was active when each block was entered.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ accepts no extra arguments, so none are forwarded.
            cls._instance = super(SuppressStdout, cls).__new__(cls)
        return cls._instance

    def __enter__(self):
        # There is no __init__, so the stack is created lazily on first use.
        saved = self.__dict__.setdefault('_saved_streams', [])
        sink = open(os.devnull, 'w')
        saved.append((sys.stdout, sink))
        self._original_stdout = sys.stdout  # kept for backward compatibility
        sys.stdout = sink
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        previous, sink = self._saved_streams.pop()
        # Restore first, then close only the handle opened by the matching
        # __enter__; the old code closed whatever sys.stdout was at exit time.
        sys.stdout = previous
        sink.close()
storage_dir, size="1G", obs="/home/opensora/obsutil_linux_arm64_5.5.12/obsutil"): self.obs = obs self.connection = ObsConnection() self.connection.connect(obs) os.makedirs(storage_dir, exist_ok=True) self.storage_dir = storage_dir self.size = self._convert_size_to_bytes(size) if not self.is_tmpfs_mounted(): self.create_ramdisk() else: print(f"{self.storage_dir} is already mounted as tmpfs.") self.index_file = os.path.join(self.storage_dir, 'index.pkl') self.index = self.load_index() self.lru = OrderedDict() self.current_size = self.get_total_storage_size() # 初始化时计算总大小 def _convert_size_to_bytes(self, size): unit = size[-1].upper() size_value = int(size[:-1]) if unit == 'K': return size_value * 1024 elif unit == 'M': return size_value * 1024 ** 2 elif unit == 'G': return size_value * 1024 ** 3 else: raise ValueError("Invalid size unit. Use K, M, or G.") """ 创建并挂载一个 tmpfs 类型的内存虚拟磁盘。 """ def create_ramdisk(self): try: # 如果挂载点目录不存在,创建它 if not os.path.exists(self.storage_dir): os.makedirs(self.storage_dir) # 挂载 tmpfs 到挂载点 subprocess.run(['sudo', 'mount', '-t', 'tmpfs', '-o', f'size={self.size}', 'tmpfs', self.storage_dir], check=True) print(f"Successfully mounted tmpfs on {self.storage_dir} with size {self.size}.") except subprocess.CalledProcessError as e: print(f"Failed to mount tmpfs: {e}") except Exception as e: print(f"An error occurred: {e}") def load_index(self): """ 加载索引文件。 :return: 索引字典。 """ if os.path.exists(self.index_file): with open(self.index_file, 'rb') as f: return pickle.load(f) return {} def save_index(self): """ 保存索引文件。 """ with open(self.index_file, 'wb') as f: pickle.dump(self.index, f) """ 取消挂载内存虚拟磁盘。 :param storage_dir: 内存虚拟磁盘的挂载点路径。 """ def unmount_ramdisk(self): try: # 确保没有进程在使用挂载点后取消挂载 subprocess.run(['sudo', 'umount', self.storage_dir], check=True) print(f"Successfully unmounted tmpfs from {self.storage_dir}.") except subprocess.CalledProcessError as e: print(f"Failed to unmount tmpfs: {e}") except Exception as e: print(f"An error occurred: 
{e}") """ 检查挂载点是否已经被挂载为 tmpfs。 :param storage_dir: 挂载点路径。 :return: 如果已挂载为 tmpfs,返回 True;否则返回 False。 """ def is_tmpfs_mounted(self): try: result = subprocess.run(['mountpoint', '-q', self.storage_dir], check=False) if result.returncode == 0: return True return False except Exception as e: print(f"An error occurred while checking if tmpfs is mounted: {e}") return False def get_data(self, key): """ 获取存储在本地磁盘上的数据。如果数据不存在,通过 obsutil 从远端获取并存储。 :param key: 数据的唯一键。 :return: 数据。 """ # if key in self.index: # data_file = self.index[key] # if os.path.exists(data_file): # self.lru.move_to_end(key) # with open(data_file, 'rb') as f: # # print(f"Successfully get {key} from local") # return pickle.load(f) # 如果数据不存在,使用 obsutil 从远端获取 object_name = key # 假设 key 对应于远端对象名称 local_path = os.path.join(self.storage_dir, key) with self.connection.suppress_stdout: self.download_and_convert_to_pickle(self.connection.bucket, object_name, local_path) # 保存数据的位置 # self.index[key] = local_path # self.save_index() # self.lru[key] = local_path # # file_size = os.path.getsize(local_path) # self.current_size += file_size # self.ensure_storage_limit() return local_path def del_data(self, local_path): os.remove(local_path) def download_and_convert_to_pickle(self, bucket, object_name, local_path): """ 使用 obsutil 从 OBS 下载文件并转换为 pickle 格式存储到本地路径。 :param bucket: OBS 存储桶名称。 :param object_name: OBS 中的对象名称。 :param local_path: 本地文件路径。 """ # try: # 下载文件到local_path路径 subprocess.run([self.obs, 'cp', f'obs://{bucket}/{object_name}', local_path], check=True) # print(f"Successfully downloaded obs://{bucket}/{object_name} to {local_path}.") # except subprocess.CalledProcessError as e: # print(f"Failed to download obs://{bucket}/{object_name} to {local_path}: {e}") def ensure_storage_limit(self): """ 确保存储总大小不超过虚拟磁盘大小,超出时根据LRU策略删除最旧的文件。 """ while self.current_size > self.size: oldest_key, oldest_path = self.lru.popitem(last=False) file_size = os.path.getsize(oldest_path) os.remove(oldest_path) del 
self.index[oldest_key] self.save_index() print(f"Removed {oldest_key} to free up {file_size} bytes.") self.current_size -= file_size def get_total_storage_size(self): """ 获取当前所有存储文件的总大小。 :return: 总大小(字节)。 """ total_size = 0 for path in self.lru.values(): if os.path.exists(path): total_size += os.path.getsize(path) return total_size ================================================ FILE: opensora/models/__init__.py ================================================ from .causalvideovae import CausalVAEModelWrapper, WFVAEModelWrapper ================================================ FILE: opensora/models/causalvideovae/__init__.py ================================================ from torchvision.transforms import Lambda from .model.vae import CausalVAEModel, WFVAEModel from einops import rearrange import torch try: import torch_npu from opensora.npu_config import npu_config except: torch_npu = None npu_config = None pass import torch.nn as nn import torch class CausalVAEModelWrapper(nn.Module): def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): super(CausalVAEModelWrapper, self).__init__() self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) def encode(self, x): x = self.vae.encode(x).sample().mul_(0.18215) return x def decode(self, x): x = self.vae.decode(x / 0.18215) x = rearrange(x, 'b c t h w -> b t c h w').contiguous() return x def dtype(self): return self.vae.dtype class WFVAEModelWrapper(nn.Module): def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs): super(WFVAEModelWrapper, self).__init__() self.vae = WFVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) self.register_buffer('shift', torch.tensor(self.vae.config.shift)[None, :, None, None, None]) self.register_buffer('scale', torch.tensor(self.vae.config.scale)[None, :, None, None, None]) def encode(self, x): x = (self.vae.encode(x).sample() - 
self.shift.to(x.device, dtype=x.dtype)) * self.scale.to(x.device, dtype=x.dtype) return x def decode(self, x): x = x / self.scale.to(x.device, dtype=x.dtype) + self.shift.to(x.device, dtype=x.dtype) x = self.vae.decode(x) x = rearrange(x, 'b c t h w -> b t c h w').contiguous() return x def dtype(self): return self.vae.dtype ae_wrapper = { 'CausalVAEModel_D4_2x8x8': CausalVAEModelWrapper, 'CausalVAEModel_D8_2x8x8': CausalVAEModelWrapper, 'CausalVAEModel_D4_4x8x8': CausalVAEModelWrapper, 'CausalVAEModel_D8_4x8x8': CausalVAEModelWrapper, 'WFVAEModel_D8_4x8x8': WFVAEModelWrapper, 'WFVAEModel_D16_4x8x8': WFVAEModelWrapper, 'WFVAEModel_D32_4x8x8': WFVAEModelWrapper, 'WFVAEModel_D32_8x8x8': WFVAEModelWrapper, } ae_stride_config = { 'CausalVAEModel_D4_2x8x8': [2, 8, 8], 'CausalVAEModel_D8_2x8x8': [2, 8, 8], 'CausalVAEModel_D4_4x8x8': [4, 8, 8], 'CausalVAEModel_D8_4x8x8': [4, 8, 8], 'WFVAEModel_D8_4x8x8': [4, 8, 8], 'WFVAEModel_D16_4x8x8': [4, 8, 8], 'WFVAEModel_D32_4x8x8': [4, 8, 8], 'WFVAEModel_D32_8x8x8': [8, 8, 8], } ae_channel_config = { 'CausalVAEModel_D4_2x8x8': 4, 'CausalVAEModel_D8_2x8x8': 8, 'CausalVAEModel_D4_4x8x8': 4, 'CausalVAEModel_D8_4x8x8': 8, 'WFVAEModel_D8_4x8x8': 8, 'WFVAEModel_D16_4x8x8': 16, 'WFVAEModel_D32_4x8x8': 32, 'WFVAEModel_D32_8x8x8': 32, } ae_denorm = { 'CausalVAEModel_D4_2x8x8': lambda x: (x + 1.) / 2., 'CausalVAEModel_D8_2x8x8': lambda x: (x + 1.) / 2., 'CausalVAEModel_D4_4x8x8': lambda x: (x + 1.) / 2., 'CausalVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2., 'WFVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2., 'WFVAEModel_D16_4x8x8': lambda x: (x + 1.) / 2., 'WFVAEModel_D32_4x8x8': lambda x: (x + 1.) / 2., 'WFVAEModel_D32_8x8x8': lambda x: (x + 1.) / 2., } ae_norm = { 'CausalVAEModel_D4_2x8x8': Lambda(lambda x: 2. * x - 1.), 'CausalVAEModel_D8_2x8x8': Lambda(lambda x: 2. * x - 1.), 'CausalVAEModel_D4_4x8x8': Lambda(lambda x: 2. * x - 1.), 'CausalVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.), 'WFVAEModel_D8_4x8x8': Lambda(lambda x: 2. 
* x - 1.), 'WFVAEModel_D16_4x8x8': Lambda(lambda x: 2. * x - 1.), 'WFVAEModel_D32_4x8x8': Lambda(lambda x: 2. * x - 1.), 'WFVAEModel_D32_8x8x8': Lambda(lambda x: 2. * x - 1.), } ================================================ FILE: opensora/models/causalvideovae/dataset/__init__.py ================================================ ================================================ FILE: opensora/models/causalvideovae/dataset/ddp_sampler.py ================================================ import math from typing import TypeVar, Optional, Iterator import torch from torch.utils.data import Sampler, Dataset import torch.distributed as dist T_co = TypeVar('T_co', covariant=True) class CustomDistributedSampler(Sampler[T_co]): r"""Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size and that any instance of it always returns the same elements in the same order. Args: dataset: Dataset used for sampling. num_replicas (int, optional): Number of processes participating in distributed training. By default, :attr:`world_size` is retrieved from the current distributed group. rank (int, optional): Rank of the current process within :attr:`num_replicas`. By default, :attr:`rank` is retrieved from the current distributed group. shuffle (bool, optional): If ``True`` (default), sampler will shuffle the indices. seed (int, optional): random seed used to shuffle the sampler if :attr:`shuffle=True`. This number should be identical across all processes in the distributed group. Default: ``0``. 
drop_last (bool, optional): if ``True``, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If ``False``, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: ``False``. .. warning:: In distributed mode, calling the :meth:`set_epoch` method at the beginning of each epoch **before** creating the :class:`DataLoader` iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used. Example:: >>> # xdoctest: +SKIP >>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader) """ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() if rank >= num_replicas or rank < 0: raise ValueError( f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.current_index = 0 self.drop_last = drop_last # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] # Split to nearest available length that is evenly divisible. # This is to ensure each rank receives the same amount of data when # using this Sampler. 
self.num_samples = math.ceil( (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed def __iter__(self) -> Iterator[T_co]: if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] else: indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): indices += indices[:padding_size] else: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. indices = indices[:self.total_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank:self.total_size:self.num_replicas] assert len(indices) == self.num_samples while self.current_index < len(indices): yield indices[self.current_index] self.current_index += 1 self.current_index = 0 def __len__(self) -> int: return self.num_samples def set_epoch(self, epoch: int) -> None: r""" Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering. Args: epoch (int): Epoch number. 
""" self.epoch = epoch def state_dict(self) -> dict: return { 'epoch': self.epoch, 'seed': self.seed, 'current_index': self.current_index } def load_state_dict(self, state_dict: dict) -> None: self.epoch = state_dict['epoch'] self.seed = state_dict['seed'] self.current_index = state_dict.get('current_index', 0) ================================================ FILE: opensora/models/causalvideovae/dataset/transform.py ================================================ import torch import random import numbers from torchvision.transforms import RandomCrop, RandomResizedCrop def _is_tensor_video_clip(clip): if not torch.is_tensor(clip): raise TypeError("clip should be Tensor. Got %s" % type(clip)) if not clip.ndimension() == 4: raise ValueError("clip should be 4D. Got %dD" % clip.dim()) return True def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize( tuple(x // 2 for x in pil_image.size), resample=Image.BOX ) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize( tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC ) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) def crop(clip, i, j, h, w): """ Args: clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W) """ if len(clip.size()) != 4: raise ValueError("clip should be a 4D tensor") return clip[..., i: i + h, j: j + w] def resize(clip, target_size, interpolation_mode): if len(target_size) != 2: raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True) def resize_scale(clip, target_size, interpolation_mode): if len(target_size) != 2: raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") H, W = clip.size(-2), clip.size(-1) scale_ = target_size[0] / min(H, W) return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True) def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): """ Do spatial cropping and resizing to the video clip Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) i (int): i in (i,j) i.e coordinates of the upper left corner. j (int): j in (i,j) i.e coordinates of the upper left corner. h (int): Height of the cropped region. w (int): Width of the cropped region. size (tuple(int, int)): height and width of resized clip Returns: clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W) """ if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") clip = crop(clip, i, j, h, w) clip = resize(clip, size, interpolation_mode) return clip def center_crop(clip, crop_size): if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) th, tw = crop_size if h < th or w < tw: raise ValueError("height and width must be no smaller than crop_size") i = int(round((h - th) / 2.0)) j = int(round((w - tw) / 2.0)) return crop(clip, i, j, th, tw) def center_crop_using_short_edge(clip): if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) if h < w: th, tw = h, h i = 0 j = int(round((w - tw) / 2.0)) else: th, tw = w, w i = int(round((h - th) / 2.0)) j = 0 return crop(clip, i, j, th, tw) def random_shift_crop(clip): ''' Slide along the long edge, with the short edge as crop size ''' if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) if h <= w: long_edge = w short_edge = h else: long_edge = h short_edge = w th, tw = short_edge, short_edge i = torch.randint(0, h - th + 1, size=(1,)).item() j = torch.randint(0, w - tw + 1, size=(1,)).item() return crop(clip, i, j, th, tw) def to_tensor(clip): """ Convert tensor data type from uint8 to float, divide value by 255.0 and permute the dimensions of clip tensor Args: clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) Return: clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) """ _is_tensor_video_clip(clip) if not clip.dtype == torch.uint8: raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) # return clip.float().permute(3, 0, 1, 2) / 255.0 return clip.float() / 255.0 def normalize(clip, mean, std, inplace=False): """ Args: clip (torch.tensor): Video clip to be normalized. 
Size is (T, C, H, W) mean (tuple): pixel RGB mean. Size is (3) std (tuple): pixel standard deviation. Size is (3) Returns: normalized clip (torch.tensor): Size is (T, C, H, W) """ if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") if not inplace: clip = clip.clone() mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) # print(mean) std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) return clip def hflip(clip): """ Args: clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) Returns: flipped clip (torch.tensor): Size is (T, C, H, W) """ if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") return clip.flip(-1) class RandomCropVideo: def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: randomly cropped video clip. size is (T, C, OH, OW) """ i, j, h, w = self.get_params(clip) return crop(clip, i, j, h, w) def get_params(self, clip): h, w = clip.shape[-2:] th, tw = self.size if h < th or w < tw: raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") if w == tw and h == th: return 0, 0, h, w i = torch.randint(0, h - th + 1, size=(1,)).item() j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size})" class SpatialStrideCropVideo: def __init__(self, stride): self.stride = stride def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: cropped video clip by stride. 
size is (T, C, OH, OW) """ i, j, h, w = self.get_params(clip) return crop(clip, i, j, h, w) def get_params(self, clip): h, w = clip.shape[-2:] th, tw = h // self.stride * self.stride, w // self.stride * self.stride return 0, 0, th, tw # from top-left def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size})" class LongSideResizeVideo: ''' First use the long side, then resize to the specified size ''' def __init__( self, size, skip_low_resolution=False, interpolation_mode="bilinear", ): self.size = size self.skip_low_resolution = skip_low_resolution self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized video clip. size is (T, C, 512, *) or (T, C, *, 512) """ _, _, h, w = clip.shape if self.skip_low_resolution and max(h, w) <= self.size: return clip if h > w: w = int(w * self.size / h) h = self.size else: h = int(h * self.size / w) w = self.size resize_clip = resize(clip, target_size=(h, w), interpolation_mode=self.interpolation_mode) return resize_clip def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class CenterCropResizeVideo: ''' First use the short side for cropping length, center crop video, then resize to the specified size ''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. 
size is (T, C, crop_size, crop_size) """ clip_center_crop = center_crop_using_short_edge(clip) clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) return clip_center_crop_resize def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class UCFCenterCropVideo: ''' First scale to the specified size in equal proportion to the short edge, then center cropping ''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) clip_center_crop = center_crop(clip_resize, self.size) return clip_center_crop def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class KineticsRandomCropResizeVideo: ''' Slide along the long edge, with the short edge as crop size. And resie to the desired size. 
''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): clip_random_crop = random_shift_crop(clip) clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) return clip_resize class CenterCropVideo: def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: center cropped video clip. size is (T, C, crop_size, crop_size) """ clip_center_crop = center_crop(clip, self.size) return clip_center_crop def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class NormalizeVideo: """ Normalize the video clip by mean subtraction and division by standard deviation Args: mean (3-tuple): pixel RGB mean std (3-tuple): pixel RGB standard deviation inplace (boolean): whether do in-place normalization """ def __init__(self, mean, std, inplace=False): self.mean = mean self.std = std self.inplace = inplace def __call__(self, clip): """ Args: clip (torch.tensor): video clip must be normalized. 
Size is (C, T, H, W) """ return normalize(clip, self.mean, self.std, self.inplace) def __repr__(self) -> str: return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" class ToTensorVideo: """ Convert tensor data type from uint8 to float, divide value by 255.0 and permute the dimensions of clip tensor """ def __init__(self): pass def __call__(self, clip): """ Args: clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) Return: clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) """ return to_tensor(clip) def __repr__(self) -> str: return self.__class__.__name__ class RandomHorizontalFlipVideo: """ Flip the video clip along the horizontal direction with a given probability Args: p (float): probability of the clip being flipped. Default value is 0.5 """ def __init__(self, p=0.5): self.p = p def __call__(self, clip): """ Args: clip (torch.tensor): Size is (T, C, H, W) Return: clip (torch.tensor): Size is (T, C, H, W) """ if random.random() < self.p: clip = hflip(clip) return clip def __repr__(self) -> str: return f"{self.__class__.__name__}(p={self.p})" # ------------------------------------------------------------ # --------------------- Sampling --------------------------- # ------------------------------------------------------------ class TemporalRandomCrop(object): """Temporally crop the given frame indices at a random location. Args: size (int): Desired length of frames will be seen in the model. """ def __init__(self, size): self.size = size def __call__(self, total_frames): rand_end = max(0, total_frames - self.size - 1) begin_index = random.randint(0, rand_end) end_index = min(begin_index + self.size, total_frames) return begin_index, end_index class DynamicSampleDuration(object): """Temporally crop the given frame indices at a random location. Args: size (int): Desired length of frames will be seen in the model. 
""" def __init__(self, t_stride, extra_1): self.t_stride = t_stride self.extra_1 = extra_1 def __call__(self, t, h, w): if self.extra_1: t = t - 1 truncate_t_list = list(range(t+1))[t//2:][::self.t_stride] # need half at least truncate_t = random.choice(truncate_t_list) if self.extra_1: truncate_t = truncate_t + 1 return 0, truncate_t if __name__ == '__main__': from torchvision import transforms import torchvision.io as io import numpy as np from torchvision.utils import save_image import os vframes, aframes, info = io.read_video( filename='./v_Archery_g01_c03.avi', pts_unit='sec', output_format='TCHW' ) trans = transforms.Compose([ ToTensorVideo(), RandomHorizontalFlipVideo(), UCFCenterCropVideo(512), # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) target_video_len = 32 frame_interval = 1 total_frames = len(vframes) print(total_frames) temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) # Sampling video frames start_frame_ind, end_frame_ind = temporal_sample(total_frames) # print(start_frame_ind) # print(end_frame_ind) assert end_frame_ind - start_frame_ind >= target_video_len frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) print(frame_indice) select_vframes = vframes[frame_indice] print(select_vframes.shape) print(select_vframes.dtype) select_vframes_trans = trans(select_vframes) print(select_vframes_trans.shape) print(select_vframes_trans.dtype) select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) print(select_vframes_trans_int.dtype) print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) for i in range(target_video_len): save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1)) 
================================================ FILE: opensora/models/causalvideovae/dataset/video_dataset.py ================================================ import os.path as osp import random from glob import glob from torchvision import transforms import numpy as np import torch import torch.utils.data as data import torch.nn.functional as F import pickle import decord from torch.nn import functional as F from .transform import ToTensorVideo, CenterCropVideo from torchvision.transforms._transforms_video import CenterCropVideo as TVCenterCropVideo from torchvision.transforms import Lambda, Compose, Resize import torch import os class DecordInit(object): def __init__(self, num_threads=1): self.num_threads = num_threads self.ctx = decord.cpu(0) def __call__(self, filename): reader = decord.VideoReader( filename, ctx=self.ctx, num_threads=self.num_threads ) return reader def __repr__(self): repr_str = ( f"{self.__class__.__name__}(" f"sr={self.sr}," f"num_threads={self.num_threads})" ) return repr_str def TemporalRandomCrop(total_frames, size): rand_end = max(0, total_frames - size - 1) begin_index = random.randint(0, rand_end) end_index = min(begin_index + size, total_frames) return begin_index, end_index def _format_video_shape(video, time_compress=4, spatial_compress=8): """Prepare video for VAE""" time = video.shape[1] height = video.shape[2] width = video.shape[3] new_time = ( (time - (time - 1) % time_compress) if (time - 1) % time_compress != 0 else time ) new_height = ( (height - (height) % spatial_compress) if height % spatial_compress != 0 else height ) new_width = ( (width - (width) % spatial_compress) if width % spatial_compress != 0 else width ) return video[:, :new_time, :new_height, :new_width] class TrainVideoDataset(data.Dataset): video_exts = ["avi", "mp4", "webm"] def __init__( self, video_folder, sequence_length, train=True, resolution=64, sample_rate=1, dynamic_sample=True, cache_file=None, is_main_process=False, ): self.train = train 
self.sequence_length = sequence_length self.sample_rate = sample_rate self.resolution = resolution self.v_decoder = DecordInit() self.video_folder = video_folder self.dynamic_sample = dynamic_sample self.cache_file = cache_file self.transform = transforms.Compose( [ ToTensorVideo(), Resize(self.resolution), CenterCropVideo(self.resolution), Lambda(lambda x: 2.0 * x - 1.0), ] ) print("Building datasets...") self.is_main_process = is_main_process self.samples = self._make_dataset() def _make_dataset(self): cache_file = osp.join(self.video_folder, self.cache_file) if osp.exists(cache_file): with open(cache_file, "rb") as f: samples = pickle.load(f) else: samples = [] samples += sum( [ glob(osp.join(self.video_folder, "**", f"*.{ext}"), recursive=True) for ext in self.video_exts ], [], ) if self.is_main_process: with open(cache_file, "wb") as f: pickle.dump(samples, f) return samples def __len__(self): return len(self.samples) def __getitem__(self, idx): video_path = self.samples[idx] try: video = self.decord_read(video_path) video = self.transform(video) # T C H W -> T C H W video = video.transpose(0, 1) # T C H W -> C T H W return dict(video=video, label="") except Exception as e: print(f"Error with {e}, {video_path}") return self.__getitem__(random.randint(0, self.__len__() - 1)) def decord_read(self, path): decord_vr = self.v_decoder(path) total_frames = len(decord_vr) # Sampling video frames if self.dynamic_sample: sample_rate = random.randint(1, self.sample_rate) else: sample_rate = self.sample_rate size = self.sequence_length * sample_rate start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size) frame_indice = np.linspace( start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int ) video_data = decord_vr.get_batch(frame_indice).asnumpy() video_data = torch.from_numpy(video_data) video_data = video_data.permute(0, 3, 1, 2) return video_data def resize(x, resolution): height, width = x.shape[-2:] aspect_ratio = width / height if width <= 
class ValidVideoDataset(data.Dataset):
    """Dataset of fixed-length clips decoded from the video files under a directory.

    Each item is a dict holding the transformed clip under ``'video'`` and the
    base name of the source file under ``'file_name'``.
    """

    # File extensions (without the dot) that are collected as videos.
    video_exts = ["avi", "mp4", "webm"]

    def __init__(
        self,
        real_video_dir,
        num_frames,
        sample_rate=1,
        crop_size=None,
        resolution=128,
        is_main_process=False
    ) -> None:
        """
        Args:
            real_video_dir: Directory searched recursively for video files.
            num_frames: Number of frames returned per clip.
            sample_rate: Stride (in source frames) between sampled frames.
            crop_size: When not None, a center crop is appended to the transform.
            resolution: Value passed to ``Resize`` (and to the center crop).
            is_main_process: Only this process writes the ``idx.pkl`` file cache.
        """
        super().__init__()
        self.is_main_process = is_main_process
        self.real_video_files = self._make_dataset(real_video_dir)
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.crop_size = crop_size
        self.short_size = resolution
        self.v_decoder = DecordInit()
        # NOTE(review): the crop is built from `resolution`, not `crop_size`;
        # `crop_size` only switches cropping on. Confirm this is intended.
        self.transform = Compose(
            [
                ToTensorVideo(),
                Resize(resolution),
                CenterCropVideo(resolution) if crop_size is not None else Lambda(lambda x: x),
            ]
        )

    def _make_dataset(self, real_video_dir):
        """Return the list of video paths, using ``idx.pkl`` in the directory as a cache."""
        cache_file = osp.join(real_video_dir, "idx.pkl")
        if osp.exists(cache_file):
            # NOTE(review): pickle.load executes arbitrary code from the cache
            # file; only point this at directories you trust.
            with open(cache_file, "rb") as f:
                samples = pickle.load(f)
        else:
            samples = []
            for ext in self.video_exts:
                samples += glob(osp.join(real_video_dir, "**", f"*.{ext}"), recursive=True)
            if self.is_main_process:
                with open(cache_file, "wb") as f:
                    pickle.dump(samples, f)
        return samples

    def __len__(self):
        return len(self.real_video_files)

    def __getitem__(self, index):
        """Load and transform clip ``index``; on a decode failure fall back to clip 0.

        Raises:
            IndexError: if ``index`` is past the end of the dataset.
            Exception: the original error, when clip 0 itself cannot be loaded
                (previously this recursed until RecursionError).
        """
        # Out-of-range access is a caller error, not a broken video: report it directly.
        if index >= len(self):
            raise IndexError
        real_video_file = self.real_video_files[index]
        try:
            real_video_tensor = self._load_video(real_video_file)
            real_video_tensor = self.transform(real_video_tensor)
            video_name = os.path.basename(real_video_file)
            return {'video': real_video_tensor, 'file_name': video_name }
        except Exception:
            print(f"Video error: {self.real_video_files[index]}")
            if index == 0:
                # The fallback sample is the broken one; do not recurse forever.
                raise
            return self.__getitem__(0)

    def _load_video(self, video_path, sample_rate=None):
        """Decode ``self.num_frames`` frames from the start of ``video_path``.

        Args:
            video_path: Path handed to ``self.v_decoder``.
            sample_rate: Frame stride; falls back to ``self.sample_rate`` when falsy.

        Returns:
            Tensor obtained by permuting the decoded batch with ``(3, 0, 1, 2)``.

        Raises:
            Exception: if the file cannot be opened or has fewer than
                ``sample_rate * num_frames`` frames.
        """
        num_frames = self.num_frames
        if not sample_rate:
            sample_rate = self.sample_rate
        try:
            decord_vr = self.v_decoder(video_path)
        except Exception as exc:
            raise Exception(f"fail to load {video_path}.") from exc
        total_frames = len(decord_vr)
        sample_frames_len = sample_rate * num_frames
        if total_frames < sample_frames_len:
            raise Exception("video too short!")
        # Evenly spaced indices over the first `sample_frames_len` frames.
        frame_id_list = np.linspace(0, sample_frames_len - 1, num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_id_list).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(3, 0, 1, 2)
        return video_data
# test code / using example
def main():
    """Smoke-test ``calculate_fvd`` with both backends on constant dummy videos."""
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    # Fall back to CPU so the example also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    import json
    for method in ('videogpt', 'styleganv'):
        result = calculate_fvd(videos1, videos2, device, method=method)
        print(json.dumps(result, indent=4))
def calculate_lpips(videos1, videos2, device):
    """Compute per-frame LPIPS between two batches of videos.

    Args:
        videos1, videos2: Tensors of identical shape, documented by the file as
            [batch_size, timestamps, channel, h, w] with values in [0, 1].
        device: Device on which the LPIPS network is evaluated.

    Returns:
        dict with ``"value"`` / ``"value_std"`` mapping frame index to the mean /
        standard deviation over the batch, plus ``"video_setting"`` (shape of one
        video) and ``"video_setting_name"``.
    """
    # image should be RGB, IMPORTANT: normalized to [-1,1]
    print("calculate_lpips...")

    assert videos1.shape == videos2.shape

    # support grayscale input, if grayscale -> channel*3
    # value range [0, 1] -> [-1, 1]
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    # Moving the network is loop-invariant: do it once instead of once per frame.
    loss_fn.to(device)

    num_frames = videos1.shape[1]
    lpips_results = []
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for video_num in tqdm(range(videos1.shape[0])):
            video1 = videos1[video_num]
            video2 = videos2[video_num]

            lpips_results_of_a_video = []
            for clip_timestamp in range(num_frames):
                # img [channel, h, w] -> [1, channel, h, w]
                img1 = video1[clip_timestamp].unsqueeze(0).to(device)
                img2 = video2[clip_timestamp].unsqueeze(0).to(device)
                lpips_results_of_a_video.append(
                    loss_fn.forward(img1, img2).mean().detach().cpu().tolist()
                )
            lpips_results.append(lpips_results_of_a_video)

    # Explicit 2-D shape keeps the column indexing below valid for an empty batch.
    lpips_results = np.array(lpips_results).reshape(len(lpips_results), num_frames)

    # Named *_mean so the imported `lpips` module is not shadowed.
    lpips_mean = {}
    lpips_std = {}
    for clip_timestamp in range(num_frames):
        lpips_mean[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
        lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])

    result = {
        "value": lpips_mean,
        "value_std": lpips_std,
        "video_setting": videos1.shape[1:],
        "video_setting_name": "time, channel, heigth, width",
    }
    return result
def img_psnr(img1, img2):
    """PSNR of two NumPy images whose values lie in [0, 1].

    Returns 100 when the mean squared error is below 1e-10 (effectively
    identical images); otherwise ``20 * log10(1 / sqrt(mse))``.
    """
    difference = img1 / 1.0 - img2 / 1.0
    mse = np.mean(difference ** 2)
    if not mse < 1e-10:
        return 20 * math.log10(1 / math.sqrt(mse))
    return 100
def calculate_ssim_function(img1, img2):
    """SSIM of two images in [0, 1], averaged over channels for RGB input.

    Accepted layouts are ``(h, w)``, ``(1, h, w)`` and ``(3, h, w)``.

    Raises:
        ValueError: if the shapes differ or the layout is not one of the above.
            Every unsupported layout now raises; previously one of them fell
            through and returned ``None``.
    """
    # [0,1]
    # ssim is the only metric extremely sensitive to gray being compared to b/w
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[0]
        if channels == 3:
            return np.array([ssim(img1[i], img2[i]) for i in range(3)]).mean()
        if channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    raise ValueError('Wrong input image dimensions.')
# test code / using example
def main():
    """Smoke-test ``calculate_ssim`` on two identical all-zero dummy batches."""
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    # SSIM here is computed with NumPy/OpenCV, so no torch device is needed.

    import json
    result = calculate_ssim(videos1, videos2)
    print(json.dumps(result, indent=4))
def calculate_common_metric(args, dataloader, device):
    """Average the selected metric over every batch yielded by ``dataloader``.

    Args:
        args: Parsed CLI namespace; ``args.metric`` selects the metric and
            ``args.fvd_method`` (default ``"styleganv"``) the FVD backend.
        dataloader: Yields dicts with ``"real"`` and ``"generated"`` video batches.
        device: Device handed to the metrics that run a network (fvd, lpips).

    Returns:
        Mean of all per-frame / per-clip-length scores gathered over the batches.

    Raises:
        ValueError: if ``args.metric`` has no implementation here (e.g. "flolpips").
    """
    supported = ("fvd", "ssim", "psnr", "lpips")
    if args.metric not in supported:
        raise ValueError(
            f"metric '{args.metric}' is not implemented; choose one of {supported}"
        )

    score_list = []
    for batch_data in tqdm(dataloader):
        # Batches stay on the CPU: the psnr/ssim helpers and the videogpt FVD
        # preprocessing call `.numpy()`, and the fvd/lpips helpers move their
        # inputs to `device` themselves.
        real_videos = batch_data["real"]
        generated_videos = batch_data["generated"]
        # NOTE(review): the metric helpers document a [batch, time, channel, h, w]
        # layout; confirm the dataset transform produces that ordering.
        assert real_videos.shape[2] == generated_videos.shape[2]

        if args.metric == "fvd":
            result = calculate_fvd(
                real_videos,
                generated_videos,
                device,
                method=getattr(args, "fvd_method", "styleganv"),
            )
        elif args.metric == "ssim":
            result = calculate_ssim(real_videos, generated_videos)
        elif args.metric == "psnr":
            result = calculate_psnr(real_videos, generated_videos)
        else:
            result = calculate_lpips(real_videos, generated_videos, device)

        # Every helper returns {"value": {index: score, ...}, ...}; only the
        # numeric scores are accumulated.
        score_list += list(result["value"].values())
    return np.mean(score_list)


def main():
    """Build the paired dataset from the module-level ``args`` and print the score.

    Relies on ``args`` having been assigned at module level (see the
    ``__main__`` guard) before being called.
    """
    if args.device is None:
        device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    else:
        device = torch.device(args.device)

    if args.num_workers is None:
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available on every platform.
            num_cpus = os.cpu_count()
        num_workers = min(num_cpus, 8) if num_cpus is not None else 0
    else:
        num_workers = args.num_workers

    dataset = EvalDataset(
        args.real_video_dir,
        args.generated_video_dir,
        num_frames=args.num_frames,
        sample_rate=args.sample_rate,
        crop_size=args.crop_size,
        resolution=args.resolution,
    )

    if args.subset_size:
        # Clamp so a subset larger than the dataset cannot index past its end.
        indices = range(min(args.subset_size, len(dataset)))
        dataset = Subset(dataset, indices=indices)

    dataloader = DataLoader(
        dataset, args.batch_size, num_workers=num_workers, pin_memory=True
    )

    metric_score = calculate_common_metric(args, dataloader, device)
    print(metric_score)


def parse_args():
    """Parse the command-line options of the evaluation script."""
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use")
    parser.add_argument("--real_video_dir", type=str, help=("the path of real videos`"))
    parser.add_argument(
        "--generated_video_dir", type=str, help=("the path of generated videos`")
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use. Like cuda, cuda:0 or cpu",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        # None lets main() pick min(8, num_cpus), as the help text promises.
        default=None,
        help=(
            "Number of processes to use for data loading. "
            "Defaults to `min(8, num_cpus)`"
        ),
    )
    parser.add_argument("--sample_fps", type=int, default=30)
    parser.add_argument("--resolution", type=int, default=336)
    parser.add_argument("--crop_size", type=int, default=None)
    parser.add_argument("--num_frames", type=int, default=100)
    parser.add_argument("--sample_rate", type=int, default=1)
    parser.add_argument("--subset_size", type=int, default=None)
    parser.add_argument(
        "--metric",
        type=str,
        default="fvd",
        choices=["fvd", "psnr", "ssim", "lpips", "flolpips"],
    )
    # Read by calculate_common_metric and passed by the cal_fvd.sh script.
    parser.add_argument(
        "--fvd_method",
        type=str,
        default="styleganv",
        choices=["styleganv", "videogpt"],
        help="I3D backend used when --metric is fvd",
    )

    args = parser.parse_args()

    return args
def preprocess_single(video, resolution=224, sequence_length=None):
    """Resize, center-crop and rescale one video for the I3D detector.

    Args:
        video: Tensor unpacked as (c, t, h, w) with values in [0, 1].
        resolution: Side length of the square output frames.
        sequence_length: When given, keep only the first ``sequence_length`` frames.

    Returns:
        Contiguous tensor of shape (c, t', resolution, resolution) in [-1, 1].
    """
    _, num_frames, height, width = video.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= num_frames
        video = video[:, :sequence_length]

    # scale shorter side to resolution
    ratio = resolution / min(height, width)
    if height < width:
        new_hw = (resolution, math.ceil(width * ratio))
    else:
        new_hw = (math.ceil(height * ratio), resolution)
    video = F.interpolate(video, size=new_hw, mode='bilinear', align_corners=False)

    # center crop
    _, _, height, width = video.shape
    top = (height - resolution) // 2
    left = (width - resolution) // 2
    video = video[:, :, top:top + resolution, left:left + resolution]

    # [0, 1] -> [-1, 1]
    video = (video - 0.5) * 2
    return video.contiguous()
def load_i3d_pretrained(device=torch.device('cpu')):
    """Load the Kinetics-400 I3D network, downloading its weights on first use.

    The weights are stored next to this file as ``i3d_pretrained_400.pt``.

    Args:
        device: Device the network and its weights are placed on.

    Returns:
        The I3D model in eval mode, wrapped in ``torch.nn.DataParallel``.

    Raises:
        OSError, subprocess.CalledProcessError: if ``wget`` is missing or the
            download fails (a partially written file is removed first).
    """
    import subprocess

    i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
    print(filepath)
    if not os.path.exists(filepath):
        print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
        # The URL contains '&', which a shell treats as a command separator, so
        # it must be passed as a single argv entry rather than through os.system.
        try:
            subprocess.run(["wget", i3D_WEIGHTS_URL, "-O", filepath], check=True)
        except (OSError, subprocess.CalledProcessError):
            # Do not leave a truncated file behind that a later run would trust.
            if os.path.exists(filepath):
                os.remove(filepath)
            raise
    from .pytorch_i3d import InceptionI3d
    i3d = InceptionI3d(400, in_channels=3).eval().to(device)
    i3d.load_state_dict(torch.load(filepath, map_location=device))
    i3d = torch.nn.DataParallel(i3d)
    return i3d
def get_fvd_logits(videos, i3d, device, bs=10):
    """Preprocess ``videos`` and return the I3D outputs used for FVD.

    Args:
        videos: Batch handed to ``preprocess``.
        i3d: Network returned by ``load_i3d_pretrained``.
        device: Device on which ``i3d`` is evaluated.
        bs: Number of videos per forward pass (previously ignored in favour of
            a hard-coded 10).
    """
    videos = preprocess(videos)
    embeddings = get_logits(i3d, videos, device, bs=bs)
    return embeddings
def cov(m, rowvar=False):
    '''Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    The input tensor is left unmodified.

    Args:
        m: A 1-D or 2-D tensor containing multiple variables and observations.
            A 1-D tensor is treated as a single variable.
        rowvar: If `rowvar` is True, each row represents a variable, with
            observations in the columns. Otherwise (the default) each column
            represents a variable and the rows contain observations; note that
            a 2-D input with a single row is not transposed in that case.

    Returns:
        The (unbiased) covariance matrix of the variables.

    Raises:
        ValueError: if `m` has more than 2 dimensions.
    '''
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    fact = 1.0 / (m.size(1) - 1)  # unbiased estimate
    # Out-of-place on purpose: `m` may still be a view of the caller's tensor
    # (via .t() / .view()), and `m -= ...` would silently centre the caller's data.
    m = m - torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()
class MaxPool3dSamePadding(nn.MaxPool3d):
    """``nn.MaxPool3d`` that pads its input in ``forward`` based on the input size.

    ``kernel_size`` and ``stride`` must be given as 3-element sequences because
    they are indexed per dimension.
    """

    def compute_pad(self, dim, s):
        """Total padding for dimension ``dim`` (0=t, 1=h, 2=w) of input size ``s``."""
        remainder = s % self.stride[dim]
        # When the size is a multiple of the stride, a full stride is covered.
        covered = self.stride[dim] if remainder == 0 else remainder
        return max(self.kernel_size[dim] - covered, 0)

    def forward(self, x):
        # compute 'same' padding
        (batch, channel, t, h, w) = x.size()
        # F.pad takes pairs starting from the last dimension: w, then h, then t.
        pad = []
        for dim, size in ((2, w), (1, h), (0, t)):
            total = self.compute_pad(dim, size)
            front = total // 2
            pad.extend((front, total - front))
        x = F.pad(x, pad)
        return super(MaxPool3dSamePadding, self).forward(x)
We will dynamically pad based on input size in forward function bias=self._use_bias) if self._use_batch_norm: self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) def compute_pad(self, dim, s): if s % self._stride[dim] == 0: return max(self._kernel_shape[dim] - self._stride[dim], 0) else: return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) def forward(self, x): # compute 'same' padding (batch, channel, t, h, w) = x.size() out_t = np.ceil(float(t) / float(self._stride[0])) out_h = np.ceil(float(h) / float(self._stride[1])) out_w = np.ceil(float(w) / float(self._stride[2])) pad_t = self.compute_pad(0, t) pad_h = self.compute_pad(1, h) pad_w = self.compute_pad(2, w) pad_t_f = pad_t // 2 pad_t_b = pad_t - pad_t_f pad_h_f = pad_h // 2 pad_h_b = pad_h - pad_h_f pad_w_f = pad_w // 2 pad_w_b = pad_w - pad_w_f pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) x = F.pad(x, pad) x = self.conv3d(x) if self._use_batch_norm: x = self.bn(x) if self._activation_fn is not None: x = self._activation_fn(x) return x class InceptionModule(nn.Module): def __init__(self, in_channels, out_channels, name): super(InceptionModule, self).__init__() self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, name=name+'/Branch_0/Conv3d_0a_1x1') self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, name=name+'/Branch_1/Conv3d_0a_1x1') self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], name=name+'/Branch_1/Conv3d_0b_3x3') self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, name=name+'/Branch_2/Conv3d_0a_1x1') self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], name=name+'/Branch_2/Conv3d_0b_3x3') self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=0) 
self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, name=name+'/Branch_3/Conv3d_0b_1x1') self.name = name def forward(self, x): b0 = self.b0(x) b1 = self.b1b(self.b1a(x)) b2 = self.b2b(self.b2a(x)) b3 = self.b3b(self.b3a(x)) return torch.cat([b0,b1,b2,b3], dim=1) class InceptionI3d(nn.Module): """Inception-v1 I3D architecture. The model is introduced in: Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset Joao Carreira, Andrew Zisserman https://arxiv.org/pdf/1705.07750v1.pdf. See also the Inception architecture, introduced in: Going deeper with convolutions Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. http://arxiv.org/pdf/1409.4842v1.pdf. """ # Endpoints of the model in order. During construction, all the endpoints up # to a designated `final_endpoint` are returned in a dictionary as the # second return value. VALID_ENDPOINTS = ( 'Conv3d_1a_7x7', 'MaxPool3d_2a_3x3', 'Conv3d_2b_1x1', 'Conv3d_2c_3x3', 'MaxPool3d_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'MaxPool3d_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool3d_5a_2x2', 'Mixed_5b', 'Mixed_5c', 'Logits', 'Predictions', ) def __init__(self, num_classes=400, spatial_squeeze=True, final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): """Initializes I3D model instance. Args: num_classes: The number of outputs in the logit layer (default 400, which matches the Kinetics dataset). spatial_squeeze: Whether to squeeze the spatial dimensions for the logits before returning (default True). final_endpoint: The model contains many possible endpoints. `final_endpoint` specifies the last endpoint for the model to be built up to. In addition to the output at `final_endpoint`, all the outputs at endpoints up to `final_endpoint` will also be returned, in a dictionary. 
`final_endpoint` must be one of InceptionI3d.VALID_ENDPOINTS (default 'Logits'). name: A string (optional). The name of this module. Raises: ValueError: if `final_endpoint` is not recognized. """ if final_endpoint not in self.VALID_ENDPOINTS: raise ValueError('Unknown final endpoint %s' % final_endpoint) super(InceptionI3d, self).__init__() self._num_classes = num_classes self._spatial_squeeze = spatial_squeeze self._final_endpoint = final_endpoint self.logits = None if self._final_endpoint not in self.VALID_ENDPOINTS: raise ValueError('Unknown final endpoint %s' % self._final_endpoint) self.end_points = {} end_point = 'Conv3d_1a_7x7' self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) if self._final_endpoint == end_point: return end_point = 'MaxPool3d_2a_3x3' self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) if self._final_endpoint == end_point: return end_point = 'Conv3d_2b_1x1' self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, name=name+end_point) if self._final_endpoint == end_point: return end_point = 'Conv3d_2c_3x3' self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, name=name+end_point) if self._final_endpoint == end_point: return end_point = 'MaxPool3d_3a_3x3' self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) if self._final_endpoint == end_point: return end_point = 'Mixed_3b' self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) if self._final_endpoint == end_point: return end_point = 'Mixed_3c' self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) if self._final_endpoint == end_point: return end_point = 'MaxPool3d_4a_3x3' self.end_points[end_point] = 
MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0) if self._final_endpoint == end_point: return end_point = 'Mixed_4b' self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) if self._final_endpoint == end_point: return end_point = 'Mixed_4c' self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) if self._final_endpoint == end_point: return end_point = 'Mixed_4d' self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) if self._final_endpoint == end_point: return end_point = 'Mixed_4e' self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) if self._final_endpoint == end_point: return end_point = 'Mixed_4f' self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) if self._final_endpoint == end_point: return end_point = 'MaxPool3d_5a_2x2' self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0) if self._final_endpoint == end_point: return end_point = 'Mixed_5b' self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) if self._final_endpoint == end_point: return end_point = 'Mixed_5c' self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) if self._final_endpoint == end_point: return end_point = 'Logits' self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1)) self.dropout = nn.Dropout(dropout_keep_prob) self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, kernel_shape=[1, 1, 1], padding=0, activation_fn=None, use_batch_norm=False, use_bias=True, name='logits') self.build() def replace_logits(self, num_classes): self._num_classes = num_classes self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, kernel_shape=[1, 1, 
def forward(self, x):
    """Run every registered end point, then the logits head, and average dim 2.

    Args:
        x: input tensor fed to the first registered end point.

    Returns:
        The head output averaged over dim 2 (presumably the temporal axis of
        an (N, C, T, H, W) tensor -- TODO confirm against the caller). When
        ``self._spatial_squeeze`` is true, dims 3 and 4 are squeezed away
        before the average.

    Previously ``logits`` was only assigned inside the ``_spatial_squeeze``
    branch, so a model built with ``spatial_squeeze=False`` raised
    ``UnboundLocalError``. The un-squeezed head output is now used instead.
    """
    for end_point in self.VALID_ENDPOINTS:
        if end_point in self.end_points:
            x = self._modules[end_point](x)  # use _modules to work with dataparallel

    logits = self.logits(self.dropout(self.avg_pool(x)))
    if self._spatial_squeeze:
        # Two successive squeeze(3) calls remove original dims 3 and 4
        # (each is a no-op if that dim is not of size 1).
        logits = logits.squeeze(3).squeeze(3)
    return logits.mean(dim=2)
path/to/imageA\ --generated_video_dir path/to/imageB \ --batch_size 10 \ --num_frames 20 \ --crop_size 64 \ --device 'cuda' \ --metric 'lpips' ================================================ FILE: opensora/models/causalvideovae/eval/script/cal_psnr.sh ================================================ python eval_common_metric.py \ --real_video_dir /data/xiaogeng_liu/data/video1 \ --generated_video_dir /data/xiaogeng_liu/data/video2 \ --batch_size 10 \ --num_frames 20 \ --crop_size 64 \ --device 'cuda' \ --metric 'psnr' ================================================ FILE: opensora/models/causalvideovae/eval/script/cal_ssim.sh ================================================ python eval_common_metric.py \ --real_video_dir /data/xiaogeng_liu/data/video1 \ --generated_video_dir /data/xiaogeng_liu/data/video2 \ --batch_size 10 \ --num_frames 20 \ --crop_size 64 \ --device 'cuda' \ --metric 'ssim' ================================================ FILE: opensora/models/causalvideovae/model/__init__.py ================================================ from .registry import ModelRegistry from .vae import ( CausalVAEModel, WFVAEModel ) ================================================ FILE: opensora/models/causalvideovae/model/configuration_videobase.py ================================================ import json import yaml from typing import TypeVar, Dict, Any from diffusers import ConfigMixin T = TypeVar('T', bound='VideoBaseConfiguration') class VideoBaseConfiguration(ConfigMixin): config_name = "VideoBaseConfiguration" _nested_config_fields: Dict[str, Any] = {} def __init__(self, **kwargs): pass def to_dict(self) -> Dict[str, Any]: d = {} for key, value in vars(self).items(): if isinstance(value, VideoBaseConfiguration): d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances elif isinstance(value, tuple): d[key] = list(value) else: d[key] = value return d def to_yaml_file(self, yaml_path: str): with open(yaml_path, 'w') as yaml_file: 
def TemporalRandomCrop(total_frames, size):
    """
    Pick a random contiguous window of up to `size` frames.

    Parameters:
    - total_frames (int): The total number of frames in the video sequence.
    - size (int): The requested window length in frames.

    Returns:
    - (int, int): `(begin_index, end_index)` where `begin_index` is the first
      frame of the window and `end_index` is EXCLUSIVE (one past the last
      frame), i.e. the window is `range(begin_index, end_index)`. When the
      video is shorter than `size` the window is clamped to `(0, total_frames)`.
    """
    # The last valid start is total_frames - size. The previous
    # `total_frames - size - 1` bound meant a window ending on the final
    # frame could never be drawn.
    rand_end = max(0, total_frames - size)
    begin_index = random.randint(0, rand_end)  # randint is inclusive on both ends
    end_index = min(begin_index + size, total_frames)
    return begin_index, end_index
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Typical sequence: `register()` once, `update()` after optimizer steps,
    `apply_shadow()` to swap the averaged weights in, `restore()` to swap the
    live weights back.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged tensor
        self.backup = {}  # parameter name -> live tensor stashed by apply_shadow

    def _tracked(self):
        """Yield (name, param) for every model parameter that has a shadow entry."""
        for name, param in self.model.named_parameters():
            if name in self.shadow:
                yield name, param

    def register(self):
        """Snapshot every parameter that currently requires grad."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name] = param.data.clone()

    def update(self):
        """Blend the current weights into the shadow: (1 - decay) * w + decay * shadow."""
        for name, param in self._tracked():
            blended = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
            self.shadow[name] = blended.clone()

    def apply_shadow(self):
        """Stash the live weights in `backup` and point the parameters at the shadow tensors."""
        for name, param in self._tracked():
            self.backup[name] = param.data
            # No copy is made: the parameter shares storage with the shadow
            # until restore() is called.
            param.data = self.shadow[name]

    def restore(self):
        """Put back the weights stashed by `apply_shadow()` and clear the stash."""
        for name, param in self._tracked():
            param.data = self.backup[name]
        self.backup = {}
class NLayerDiscriminator3D(nn.Module):
    """3D PatchGAN discriminator: a stack of strided Conv3d / BatchNorm3d / LeakyReLU stages."""

    def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
        """
        Build the discriminator.

        Parameters:
            input_nc (int)     -- number of channels in the input volumes
            ndf (int)          -- number of filters produced by the first conv layer;
                                  later stages use ndf * min(2 ** n, 8)
            n_layers (int)     -- number of downsampling stages
            use_actnorm (bool) -- must be False; actnorm is not implemented

        Raises:
            NotImplementedError: if `use_actnorm` is true.
        """
        super(NLayerDiscriminator3D, self).__init__()
        if use_actnorm:
            raise NotImplementedError("Not implemented.")
        norm_layer = nn.BatchNorm3d
        # The norm layer is always BatchNorm3d here, so this is always False:
        # convs followed by a norm layer are built without a bias term.
        use_bias = norm_layer is not nn.BatchNorm3d

        kw, padw = 3, 1

        def stage(mult_in, mult_out, stride):
            """One conv -> norm -> LeakyReLU group mapping ndf*mult_in to ndf*mult_out channels."""
            return [
                nn.Conv3d(
                    ndf * mult_in,
                    ndf * mult_out,
                    kernel_size=(kw, kw, kw),
                    stride=stride,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * mult_out),
                nn.LeakyReLU(0.2, True),
            ]

        layers = [
            nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            prev, mult = mult, min(2 ** n, 8)
            # Only the first of these stages strides over dim 0 of the kernel;
            # the others stride over the last two dims only.
            layers.extend(stage(prev, mult, (2 if n == 1 else 1, 2, 2)))

        prev, mult = mult, min(2 ** n_layers, 8)
        layers.extend(stage(prev, mult, 1))

        # output 1 channel prediction map
        layers.append(nn.Conv3d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch-wise prediction map produced by the sequential stack."""
        return self.main(input)
"""Standard forward.""" # return self.main(input) ================================================ FILE: opensora/models/causalvideovae/model/losses/lpips.py ================================================ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" import torch import torch.nn as nn from torchvision import models from collections import namedtuple from .....utils.taming_download import get_ckpt_path class LPIPS(nn.Module): # Learned perceptual metric def __init__(self, use_dropout=True): super().__init__() self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features self.net = vgg16(pretrained=True, requires_grad=False) self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.load_from_pretrained() for param in self.parameters(): param.requires_grad = False def load_from_pretrained(self, name="vgg_lpips"): ckpt = get_ckpt_path(name, ".cache/lpips") self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) print("loaded pretrained LPIPS loss from {}".format(ckpt)) @classmethod def from_pretrained(cls, name="vgg_lpips"): if name != "vgg_lpips": raise NotImplementedError model = cls() ckpt = get_ckpt_path(name) model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) return model def forward(self, input, target): in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) outs0, outs1 = self.net(in0_input), self.net(in1_input) feats0, feats1, diffs = {}, {}, {} lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] for kk in range(len(self.chns)): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = 
class ScalingLayer(nn.Module):
    """Per-channel affine normalisation: (inp - shift) / scale with fixed 3-channel constants."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Constants are stored as (1, 3, 1, 1) buffers so they broadcast over
        # 4-D inputs with channels in dim 1 and move with the module across devices.
        constants = (
            ('shift', [-.030, -.088, -.188]),
            ('scale', [.458, .448, .450]),
        )
        for buf_name, values in constants:
            self.register_buffer(buf_name, torch.Tensor(values).view(1, -1, 1, 1))

    def forward(self, inp):
        """Return `inp` shifted and scaled channel-wise."""
        centered = inp - self.shift
        return centered / self.scale
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss with a per-sample weight.

    Args:
        logits_real: discriminator output for real samples, batch in dim 0.
        logits_fake: discriminator output for generated samples, batch in dim 0.
        weights: 1-D tensor with one weight per batch element.

    Returns:
        Scalar tensor: 0.5 * (weighted mean real hinge + weighted mean fake hinge).

    The per-sample hinge terms are averaged over every non-batch dimension.
    The previous hard-coded ``dim=[1, 2, 3]`` only fully reduced 4-D logits.
    With the 5-D output of a 3D discriminator it left a trailing axis, which
    then failed to broadcast against (or was mis-broadcast with) ``weights``.
    """
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]

    def per_sample_mean(hinge):
        # Flatten everything after the batch axis so any rank is reduced to (batch,).
        return hinge.reshape(hinge.shape[0], -1).mean(dim=1)

    loss_real = per_sample_mean(F.relu(1.0 - logits_real))
    loss_fake = per_sample_mean(F.relu(1.0 + logits_fake))
    weight_total = weights.sum()
    loss_real = (weights * loss_real).sum() / weight_total
    loss_fake = (weights * loss_fake).sum() / weight_total
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
when perplexity == num_embeddings then all clusters are used exactly equally encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use def l1(x, y): return torch.abs(x - y) def l2(x, y): return torch.pow((x - y), 2) class LPIPSWithDiscriminator3D(nn.Module): def __init__( self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, perceptual_weight=1.0, disc_num_layers=4, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, use_actnorm=False, disc_conditional=False, disc_loss="hinge", learn_logvar: bool = False, wavelet_weight=0.01, loss_type: str = "l1", ): super().__init__() assert disc_loss in ["hinge", "vanilla"] self.wavelet_weight = wavelet_weight self.kl_weight = kl_weight self.pixel_weight = pixelloss_weight self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight self.logvar = nn.Parameter( torch.full((), logvar_init), requires_grad=learn_logvar ) self.discriminator = NLayerDiscriminator3D( input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm ).apply(weights_init) self.discriminator_iter_start = disc_start self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional self.loss_func = l1 if loss_type == "l1" else l2 def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): layer = last_layer if last_layer is not None else self.last_layer[0] nll_grads = torch.autograd.grad(nll_loss, layer, retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, layer, retain_graph=True)[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def forward( self, 
inputs, reconstructions, posteriors, optimizer_idx, global_step, split="train", weights=None, last_layer=None, wavelet_coeffs=None, cond=None, ): bs = inputs.shape[0] t = inputs.shape[2] if optimizer_idx == 0: inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() reconstructions = rearrange( reconstructions, "b c t h w -> (b t) c h w" ).contiguous() rec_loss = self.loss_func(inputs, reconstructions) if self.perceptual_weight > 0: p_loss = self.perceptual_loss(inputs, reconstructions) rec_loss = rec_loss + self.perceptual_weight * p_loss nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: weighted_nll_loss = weights * nll_loss weighted_nll_loss = ( torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] ) nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] kl_loss = posteriors.kl() kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] if wavelet_coeffs: wl_loss_l2 = torch.sum(l1(wavelet_coeffs[0], wavelet_coeffs[1])) / bs wl_loss_l3 = torch.sum(l1(wavelet_coeffs[2], wavelet_coeffs[3])) / bs wl_loss = wl_loss_l2 + wl_loss_l3 else: wl_loss = torch.tensor(0.0) inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t).contiguous() reconstructions = rearrange( reconstructions, "(b t) c h w -> b c t h w", t=t ).contiguous() logits_fake = self.discriminator(reconstructions) g_loss = -torch.mean(logits_fake) if global_step >= self.discriminator_iter_start: if self.disc_factor > 0.0: d_weight = self.calculate_adaptive_weight( nll_loss, g_loss, last_layer=last_layer ) else: d_weight = torch.tensor(1.0) else: d_weight = torch.tensor(0.0) g_loss = torch.tensor(0.0, requires_grad=True) disc_factor = adopt_weight( self.disc_factor, global_step, threshold=self.discriminator_iter_start ) loss = ( weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + self.wavelet_weight * wl_loss ) log = { "{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): 
class VideoBaseAE(ModelMixin, ConfigMixin):
    """Base class for the video autoencoders; adds `.ckpt`-directory loading on top of diffusers' mixins."""

    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Placeholder; returns None here. Subclasses are expected to override it."""
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Placeholder; returns None here. Subclasses are expected to override it."""
        pass

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        # NOTE(review): relies on `self.trainer` and `self.train_dataloader()`,
        # neither of which is defined in this class -- presumably supplied by a
        # Lightning-style training wrapper; confirm before use.
        if self.trainer.max_steps:
            return self.trainer.max_steps

        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        # An int limit is an absolute cap; anything else is treated as a fraction.
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        """Load a model from a directory.

        If the directory contains any `*.ckpt` file, the model is built from
        `<dir>/<config_name>` and weights are loaded from the checkpoint whose
        path sorts last. Otherwise loading is delegated to the parent
        `from_pretrained` with the original arguments.
        """
        # glob returns matches in arbitrary, filesystem-dependent order, so
        # sort to make "the last checkpoint" deterministic across machines.
        ckpt_files = sorted(glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt')))
        if ckpt_files:
            # Adapt to checkpoint
            last_ckpt_file = ckpt_files[-1]
            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            model = cls.from_config(config_file)
            model.init_from_ckpt(last_ckpt_file)
            return model
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
class AttnBlock3DFix(nn.Module):
    """
    Residual self-attention block over the spatial positions of each frame.

    The input is normalised, projected to q/k/v with 1x1 CausalConv3d layers,
    and attention is computed over the h*w positions of every frame
    independently (frames are folded into the batch axis). The result is
    projected and added back to the input.

    Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
    """

    def __init__(self, in_channels, norm_type="groupnorm"):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels, norm_type=norm_type)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Return `x + proj_out(attention(norm(x)))`; `x` is unpacked as (b, c, t, h, w)."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, t, h, w = q.shape
        # Channels-last, then fold time into batch: (b, c, t, h, w) -> (b*t, h*w, c).
        q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
        k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
        v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()

        if torch_npu is None:
            # attn_output = xops.memory_efficient_attention(
            #     q, k, v,
            #     scale=c ** -0.5
            # )
            # Insert a singleton head axis: (b*t, h*w, c) -> (b*t, 1, h*w, c).
            q = q.view(b * t, -1, 1, c).transpose(1, 2)
            k = k.view(b * t, -1, 1, c).transpose(1, 2)
            v = v.view(b * t, -1, 1, c).transpose(1, 2)
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            # Drop the head axis again: (b*t, 1, h*w, c) -> (b*t, h*w, c).
            attn_output = attn_output.transpose(1, 2).reshape(b * t, -1, 1 * c)
        else:
            # print('npu_config.enable_FA, q.dtype == torch.float32', npu_config.enable_FA, q.dtype == torch.float32)
            # NPU path: request bfloat16 only when enable_FA is set and the
            # inputs are float32; otherwise pass dtype=None.
            if npu_config.enable_FA and q.dtype == torch.float32:
                dtype = torch.bfloat16
            else:
                dtype = None
            # NOTE(review): set_run_dtype / set_current_run_dtype / restore_dtype are
            # defined in opensora.npu_config and not visible here -- presumably they
            # cast to `dtype` for the attention call and cast back afterwards; confirm.
            with set_run_dtype(q, dtype):
                query, key, value = npu_config.set_current_run_dtype([q, k, v])
                hidden_states = npu_config.run_attention(query, key, value, atten_mask=None, input_layout="BSH", head_dim=c, head_num=1)
                attn_output = npu_config.restore_dtype(hidden_states)

        # Unfold back to (b, t, h, w, c) and return to channels-first (b, c, t, h, w).
        attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
        h_ = self.proj_out(attn_output)
        return x + h_
class CausalConv3d_GC(CausalConv3d):
    """CausalConv3d variant whose convolution runs under `torch.utils.checkpoint.checkpoint`.

    Unlike the parent's `forward`, this one never reads or writes the chunk
    cache: every call pads with copies of the first frame.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]],
        init_method="random",
        **kwargs
    ):
        """
        Args:
            chan_in: input channel count, forwarded to the parent.
            chan_out: output channel count, forwarded to the parent.
            kernel_size: int or tuple, forwarded to the parent.
            init_method: accepted for call-site compatibility; not used by this class.
            **kwargs: forwarded to the parent (e.g. `stride`, `padding`, `bias`).
        """
        # `init_method` used to be passed positionally here, where it landed in
        # the parent's 4th parameter `enable_cached` (so enable_cached became
        # the truthy string "random"). The parent has no init_method parameter,
        # so it is simply not forwarded.
        super().__init__(chan_in, chan_out, kernel_size, **kwargs)

    def forward(self, x):
        """Left-pad dim 2 with `time_kernel_size - 1` copies of the first frame, then convolve."""
        # 1 + 16   16 as video, 1 as image
        first_frame_pad = x[:, :, :1, :, :].repeat(
            (1, 1, self.time_kernel_size - 1, 1, 1)
        )  # b c t h w
        x = torch.concatenate((first_frame_pad, x), dim=2)  # 3 + 16
        return checkpoint(self.conv, x)
================================================ FILE: opensora/models/causalvideovae/model/modules/normalize.py ================================================ import torch import torch.nn as nn from .block import Block from einops import rearrange class GroupNorm(Block): def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.norm = torch.nn.GroupNorm( num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True ) def forward(self, x): return self.norm(x) class LayerNorm(Block): def __init__(self, num_channels, eps=1e-6, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.norm = torch.nn.LayerNorm(num_channels, eps=eps, elementwise_affine=True) def forward(self, x): if x.dim() == 5: x = rearrange(x, "b c t h w -> b t h w c") x = self.norm(x) x = rearrange(x, "b t h w c -> b c t h w") else: x = rearrange(x, "b c h w -> b h w c") x = self.norm(x) x = rearrange(x, "b h w c -> b c h w") return x def Normalize(in_channels, num_groups=32, norm_type="groupnorm"): if norm_type == "groupnorm": return torch.nn.GroupNorm( num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True ) elif norm_type == "layernorm": return LayerNorm(num_channels=in_channels, eps=1e-6) ================================================ FILE: opensora/models/causalvideovae/model/modules/ops.py ================================================ import torch from einops import rearrange def video_to_image(func): def wrapper(self, x, *args, **kwargs): if x.dim() == 5: t = x.shape[2] if True: x = rearrange(x, "b c t h w -> (b t) c h w") x = func(self, x, *args, **kwargs) x = rearrange(x, "(b t) c h w -> b c t h w", t=t) else: # Conv 2d slice infer result = [] for i in range(t): frame = x[:, :, i, :, :] frame = func(self, frame, *args, **kwargs) result.append(frame.unsqueeze(2)) x = torch.concatenate(result, dim=2) return x return wrapper def nonlinearity(x): return x * torch.sigmoid(x) def cast_tuple(t, 
length=1): return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length) def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): n_dims = len(x.shape) if src_dim < 0: src_dim = n_dims + src_dim if dest_dim < 0: dest_dim = n_dims + dest_dim assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims dims = list(range(n_dims)) del dims[src_dim] permutation = [] ctr = 0 for i in range(n_dims): if i == dest_dim: permutation.append(src_dim) else: permutation.append(dims[ctr]) ctr += 1 x = x.permute(permutation) if make_contiguous: x = x.contiguous() return x ================================================ FILE: opensora/models/causalvideovae/model/modules/quant.py ================================================ import torch import torch.nn as nn import torch.distributed as dist import numpy as np import torch.nn.functional as F from .ops import shift_dim class Codebook(nn.Module): def __init__(self, n_codes, embedding_dim): super().__init__() self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) self.register_buffer("N", torch.zeros(n_codes)) self.register_buffer("z_avg", self.embeddings.data.clone()) self.n_codes = n_codes self.embedding_dim = embedding_dim self._need_init = True def _tile(self, x): d, ew = x.shape if d < self.n_codes: n_repeats = (self.n_codes + d - 1) // d std = 0.01 / np.sqrt(ew) x = x.repeat(n_repeats, 1) x = x + torch.randn_like(x) * std return x def _init_embeddings(self, z): # z: [b, c, t, h, w] self._need_init = False flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) y = self._tile(flat_inputs) d = y.shape[0] _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] if dist.is_initialized(): dist.broadcast(_k_rand, 0) self.embeddings.data.copy_(_k_rand) self.z_avg.data.copy_(_k_rand) self.N.data.copy_(torch.ones(self.n_codes)) def forward(self, z): # z: [b, c, t, h, w] if self._need_init and self.training: self._init_embeddings(z) flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) distances = ( 
(flat_inputs**2).sum(dim=1, keepdim=True) - 2 * flat_inputs @ self.embeddings.t() + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) ) encoding_indices = torch.argmin(distances, dim=1) encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) embeddings = F.embedding(encoding_indices, self.embeddings) embeddings = shift_dim(embeddings, -1, 1) commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) # EMA codebook update if self.training: n_total = encode_onehot.sum(dim=0) encode_sum = flat_inputs.t() @ encode_onehot if dist.is_initialized(): dist.all_reduce(n_total) dist.all_reduce(encode_sum) self.N.data.mul_(0.99).add_(n_total, alpha=0.01) self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) n = self.N.sum() weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n encode_normalized = self.z_avg / weights.unsqueeze(1) self.embeddings.data.copy_(encode_normalized) y = self._tile(flat_inputs) _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] if dist.is_initialized(): dist.broadcast(_k_rand, 0) usage = (self.N.view(self.n_codes, 1) >= 1).float() self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) embeddings_st = (embeddings - z).detach() + z avg_probs = torch.mean(encode_onehot, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) return dict( embeddings=embeddings_st, encodings=encoding_indices, commitment_loss=commitment_loss, perplexity=perplexity, ) def dictionary_lookup(self, encodings): embeddings = F.embedding(encodings, self.embeddings) return embeddings ================================================ FILE: opensora/models/causalvideovae/model/modules/resnet_block.py ================================================ try: import torch_npu from opensora.npu_config import npu_config except: torch_npu = None npu_config = None import torch from .normalize import Normalize from .ops import nonlinearity, video_to_image 
from .conv import CausalConv3d from .block import Block from torch.utils.checkpoint import checkpoint class ResnetBlock2D(Block): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, norm_type, dropout, ): super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels, norm_type=norm_type) self.conv1 = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: self.nin_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) @video_to_image def forward(self, x): h = x if npu_config is None: h = self.norm1(h) else: h = npu_config.run_group_norm(self.norm1, h) h = nonlinearity(h) h = self.conv1(h) if npu_config is None: h = self.norm2(h) else: h = npu_config.run_group_norm(self.norm2, h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) x = x + h return x class ResnetBlock3D(Block): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, norm_type, ): super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels, norm_type=norm_type) self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = torch.nn.Dropout(dropout) 
self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = CausalConv3d( in_channels, out_channels, 3, padding=1 ) else: self.nin_shortcut = CausalConv3d( in_channels, out_channels, 1, padding=0 ) def forward(self, x): h = x if npu_config is None: h = self.norm1(h) else: h = npu_config.run_group_norm(self.norm1, h) h = nonlinearity(h) h = self.conv1(h) if npu_config is None: h = self.norm2(h) else: h = npu_config.run_group_norm(self.norm2, h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h class ResnetBlock3D_GC(Block): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, norm_type, dropout, ): super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels, norm_type=norm_type) self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = torch.nn.Dropout(dropout) self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = CausalConv3d( in_channels, out_channels, 3, padding=1 ) else: self.nin_shortcut = CausalConv3d( in_channels, out_channels, 1, padding=0 ) def forward(self, x): return checkpoint(self._forward, x, use_reentrant=True) def _forward(self, x): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h ================================================ FILE: 
opensora/models/causalvideovae/model/modules/updownsample.py ================================================ from typing import Union, Tuple from collections import deque import torch import torch.nn as nn import torch.nn.functional as F from .ops import cast_tuple, video_to_image from .conv import CausalConv3d, CausalConv3d_GC from einops import rearrange from .block import Block try: import torch_npu from opensora.npu_config import npu_config except: torch_npu = None npu_config = None class Upsample(Block): def __init__(self, in_channels, out_channels): super().__init__() self.with_conv = True if self.with_conv: self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) @video_to_image def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(Block): def __init__(self, in_channels, out_channels, undown=False): super().__init__() self.with_conv = True self.undown = undown if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves if self.undown: self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) @video_to_image def forward(self, x): if self.with_conv: if self.undown: if npu_config is not None and npu_config.on_npu: x_dtype = x.dtype x = x.to(npu_config.replaced_type) x = npu_config.run_conv3d(self.conv, x, x_dtype) else: x = self.conv(x) else: pad = (0, 1, 0, 1) if npu_config is not None and npu_config.on_npu: x_dtype = x.dtype x = x.to(npu_config.replaced_type) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = npu_config.run_conv3d(self.conv, x, x_dtype) else: x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class SpatialDownsample2x(Block): def __init__( 
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]] = (3, 3), stride: Union[int, Tuple[int]] = (2, 2), **kwargs ): super().__init__() kernel_size = cast_tuple(kernel_size, 2) stride = cast_tuple(stride, 2) self.chan_in = chan_in self.chan_out = chan_out self.kernel_size = kernel_size self.conv = CausalConv3d( self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1, ) + stride, padding=0 ) def forward(self, x): pad = (0,1,0,1,0,0) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x class SpatialUpsample2x_GC(Block): def __init__( self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]] = (3, 3), stride: Union[int, Tuple[int]] = (1, 1), unup=False, ): super().__init__() self.chan_in = chan_in self.chan_out = chan_out self.kernel_size = kernel_size self.unup = unup self.conv = CausalConv3d_GC( self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1, ) + stride, padding=1 ) def forward(self, x): if not self.unup: t = x.shape[2] x = rearrange(x, "b c t h w -> b (c t) h w") x = F.interpolate(x, scale_factor=(2,2), mode="nearest") x = rearrange(x, "b (c t) h w -> b c t h w", t=t) x = self.conv(x) return x class SpatialUpsample2x(Block): def __init__( self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]] = (3, 3), stride: Union[int, Tuple[int]] = (1, 1), unup=False, ): super().__init__() self.chan_in = chan_in self.chan_out = chan_out self.kernel_size = kernel_size self.unup = unup self.conv = CausalConv3d( self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1, ) + stride, padding=1 ) def forward(self, x): if not self.unup: t = x.shape[2] x = rearrange(x, "b c t h w -> b (c t) h w") x = F.interpolate(x, scale_factor=(2,2), mode="nearest") x = rearrange(x, "b (c t) h w -> b c t h w", t=t) x = self.conv(x) return x class TimeDownsample2x(Block): def __init__( self, chan_in, chan_out, kernel_size: int = 3 ): super().__init__() self.kernel_size = kernel_size if npu_config is not None and 
npu_config.on_npu: self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1)) self.pad = nn.ReplicationPad3d((0, 0, 0, 0, self.kernel_size - 1, 0)) else: self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1)) def forward(self, x): if npu_config is not None and npu_config.on_npu: n, c, d, h, w = x.shape x = self.pad(x) x = x.view(n * c, -1, h * w) pooled = self.avg_pool(x) output = pooled.view(n, c, -1, h, w) return output else: first_frame_pad = x[:, :, :1, :, :].repeat( (1, 1, self.kernel_size - 1, 1, 1) ) x = torch.concatenate((first_frame_pad, x), dim=2) return self.conv(x) class TimeUpsample2x(Block): def __init__( self, chan_in, chan_out ): super().__init__() def forward(self, x): if x.size(2) > 1: x,x_= x[:,:,:1],x[:,:,1:] x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') x = torch.concat([x, x_], dim=2) return x class TimeDownsampleRes2x(Block): def __init__( self, in_channels, out_channels, kernel_size: int = 3, mix_factor: float = 2.0, ): super().__init__() self.kernel_size = cast_tuple(kernel_size, 3) if npu_config is not None and npu_config.on_npu: self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1)) self.pad = nn.ReplicationPad3d((0, 0, 0, 0, kernel_size - 1, 0)) else: self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1)) self.conv = nn.Conv3d( in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1) ) self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) def forward(self, x): alpha = torch.sigmoid(self.mix_factor) if npu_config is not None and npu_config.on_npu: n, c, d, h, w = x.shape x_dtype = x.dtype x = x.to(npu_config.replaced_type) x = self.pad(x) pad_x = x.view(n, c, -1, h, w) avg_x = self.avg_pool(x.view(n * c, -1, h * w)).view(n, c, -1, h, w).to(x_dtype) conv_x = npu_config.run_conv3d(self.conv, pad_x, x_dtype) return alpha * avg_x + (1 - alpha) * conv_x else: first_frame_pad = x[:, :, :1, :, :].repeat( (1, 1, self.kernel_size[0] - 1, 1, 1) ) x = 
torch.concatenate((first_frame_pad, x), dim=2) return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x) class TimeUpsampleRes2x(Block): def __init__( self, in_channels, out_channels, kernel_size: int = 3, mix_factor: float = 2.0, ): super().__init__() self.conv = CausalConv3d( in_channels, out_channels, kernel_size, padding=1 ) self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) def forward(self, x): alpha = torch.sigmoid(self.mix_factor) if x.size(2) > 1: x,x_= x[:,:,:1],x[:,:,1:] if npu_config is not None and npu_config.on_npu: x_dtype = x_.dtype x_ = x_.to(npu_config.replaced_type) x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode='trilinear') x_ = x_.to(x_dtype) else: x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') x = torch.concat([x, x_], dim=2) return alpha * x + (1-alpha) * self.conv(x) class Spatial2xTime2x3DDownsample(Block): def __init__(self, in_channels, out_channels): super().__init__() self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2) def forward(self, x): pad = (0,1,0,1,0,0) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x class Spatial2x3DDownsample(Block): def __init__(self, in_channels, out_channels): super().__init__() self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=(1,2,2)) def forward(self, x): pad = (0,1,0,1,0,0) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x class Spatial2x3DUpsample(Block): def __init__(self, in_channels, out_channels): super().__init__() self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) def forward(self, x): x = F.interpolate(x, scale_factor=(1,2,2), mode='trilinear') return self.conv(x) class Spatial2xTime2x3DUpsample(Block): def __init__( self, in_channels, out_channels, t_interpolation="trilinear", enable_cached=False, ): super().__init__() self.t_interpolation = t_interpolation self.conv = 
CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) self.enable_cached = enable_cached self.causal_cached = deque() def forward(self, x): if x.size(2) > 1 or len(self.causal_cached) > 0: if self.enable_cached and len(self.causal_cached) > 0: x = torch.cat([self.causal_cached.popleft(), x], dim=2) self.causal_cached.append(x[:, :, -2:-1].clone()) x = F.interpolate(x, scale_factor=(2, 1, 1), mode=self.t_interpolation) x = x[:, :, 2:] x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear") else: if self.enable_cached: self.causal_cached.append(x[:, :, -1:].clone()) x, x_ = x[:, :, :1], x[:, :, 1:] x_ = F.interpolate( x_, scale_factor=(2, 1, 1), mode=self.t_interpolation ) x_ = F.interpolate(x_, scale_factor=(1, 2, 2), mode="trilinear") x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear") x = torch.concat([x, x_], dim=2) else: if self.enable_cached: self.causal_cached.append(x[:, :, -1:].clone()) x = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear") return self.conv(x) ================================================ FILE: opensora/models/causalvideovae/model/modules/wavelet.py ================================================ import torch import torch.nn.functional as F import torch.nn as nn from ..modules import CausalConv3d from ..modules.ops import video_to_image from einops import rearrange try: import torch_npu from opensora.npu_config import npu_config, set_run_dtype except Exception as e: torch_npu = None npu_config = None class HaarWaveletTransform3D(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536 g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536 hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536 gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536 h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536 g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], 
[-1, 1]]]) * 0.3536 hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536 gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536 h = h.view(1, 1, 2, 2, 2) g = g.view(1, 1, 2, 2, 2) hh = hh.view(1, 1, 2, 2, 2) gh = gh.view(1, 1, 2, 2, 2) h_v = h_v.view(1, 1, 2, 2, 2) g_v = g_v.view(1, 1, 2, 2, 2) hh_v = hh_v.view(1, 1, 2, 2, 2) gh_v = gh_v.view(1, 1, 2, 2, 2) self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.h_conv.conv.weight.data = h self.g_conv.conv.weight.data = g self.hh_conv.conv.weight.data = hh self.gh_conv.conv.weight.data = gh self.h_v_conv.conv.weight.data = h_v self.g_v_conv.conv.weight.data = g_v self.hh_v_conv.conv.weight.data = hh_v self.gh_v_conv.conv.weight.data = gh_v self.h_conv.requires_grad_(False) self.g_conv.requires_grad_(False) self.hh_conv.requires_grad_(False) self.gh_conv.requires_grad_(False) self.h_v_conv.requires_grad_(False) self.g_v_conv.requires_grad_(False) self.hh_v_conv.requires_grad_(False) self.gh_v_conv.requires_grad_(False) def forward(self, x): assert x.dim() == 5 if torch_npu is not None: dtype = x.dtype x = x.to(npu_config.conv_dtype) self.to(npu_config.conv_dtype) b = x.shape[0] x = rearrange(x, "b c t h w -> (b c) 1 t h w") low_low_low = self.h_conv(x) low_low_low = rearrange(low_low_low, "(b c) 1 t h w -> b c t h w", b=b) low_low_high = self.g_conv(x) low_low_high = rearrange(low_low_high, "(b c) 1 t h w -> b c t h w", b=b) low_high_low = self.hh_conv(x) low_high_low = 
rearrange(low_high_low, "(b c) 1 t h w -> b c t h w", b=b) low_high_high = self.gh_conv(x) low_high_high = rearrange(low_high_high, "(b c) 1 t h w -> b c t h w", b=b) high_low_low = self.h_v_conv(x) high_low_low = rearrange(high_low_low, "(b c) 1 t h w -> b c t h w", b=b) high_low_high = self.g_v_conv(x) high_low_high = rearrange(high_low_high, "(b c) 1 t h w -> b c t h w", b=b) high_high_low = self.hh_v_conv(x) high_high_low = rearrange(high_high_low, "(b c) 1 t h w -> b c t h w", b=b) high_high_high = self.gh_v_conv(x) high_high_high = rearrange(high_high_high, "(b c) 1 t h w -> b c t h w", b=b) output = torch.cat( [ low_low_low, low_low_high, low_high_low, low_high_high, high_low_low, high_low_high, high_high_low, high_high_high, ], dim=1, ) if torch_npu is not None: x = x.to(dtype) output = output.to(dtype) self.to(dtype) return output class InverseHaarWaveletTransform3D(nn.Module): def __init__(self, enable_cached=False, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.register_buffer('h', torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('g', torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('hh', torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('gh', torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('h_v', torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('g_v', torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('hh_v', torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('gh_v', torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.enable_cached = enable_cached self.is_first_chunk = True def forward(self, coeffs): assert 
coeffs.dim() == 5 if torch_npu is not None: dtype = coeffs.dtype coeffs = coeffs.to(npu_config.conv_dtype) self.h = self.h.to(npu_config.conv_dtype) self.g = self.g.to(npu_config.conv_dtype) self.hh = self.hh.to(npu_config.conv_dtype) self.gh = self.gh.to(npu_config.conv_dtype) self.h_v = self.h_v.to(npu_config.conv_dtype) self.g_v = self.g_v.to(npu_config.conv_dtype) self.hh_v = self.hh_v.to(npu_config.conv_dtype) self.gh_v = self.gh_v.to(npu_config.conv_dtype) b = coeffs.shape[0] ( low_low_low, low_low_high, low_high_low, low_high_high, high_low_low, high_low_high, high_high_low, high_high_high, ) = coeffs.chunk(8, dim=1) low_low_low = rearrange(low_low_low, "b c t h w -> (b c) 1 t h w") low_low_high = rearrange(low_low_high, "b c t h w -> (b c) 1 t h w") low_high_low = rearrange(low_high_low, "b c t h w -> (b c) 1 t h w") low_high_high = rearrange(low_high_high, "b c t h w -> (b c) 1 t h w") high_low_low = rearrange(high_low_low, "b c t h w -> (b c) 1 t h w") high_low_high = rearrange(high_low_high, "b c t h w -> (b c) 1 t h w") high_high_low = rearrange(high_high_low, "b c t h w -> (b c) 1 t h w") high_high_high = rearrange(high_high_high, "b c t h w -> (b c) 1 t h w") low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2) low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2) low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2) low_high_high = F.conv_transpose3d(low_high_high, self.gh, stride=2) high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2) high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2) high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2) high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2) if self.enable_cached and not self.is_first_chunk: reconstructed = ( low_low_low + low_low_high + low_high_low + low_high_high + high_low_low + high_low_high + high_high_low + high_high_high ) else: reconstructed = ( low_low_low[:, :, 1:] + low_low_high[:, :, 
1:] + low_high_low[:, :, 1:] + low_high_high[:, :, 1:] + high_low_low[:, :, 1:] + high_low_high[:, :, 1:] + high_high_low[:, :, 1:] + high_high_high[:, :, 1:] ) reconstructed = rearrange(reconstructed, "(b c) 1 t h w -> b c t h w", b=b) if torch_npu is not None: coeffs = coeffs.to(dtype) reconstructed = reconstructed.to(dtype) self.h = self.h.to(dtype) self.g = self.g.to(dtype) self.hh = self.hh.to(dtype) self.gh = self.gh.to(dtype) self.h_v = self.h_v.to(dtype) self.g_v = self.g_v.to(dtype) self.hh_v = self.hh_v.to(dtype) self.gh_v = self.gh_v.to(dtype) return reconstructed class HaarWaveletTransform2D(nn.Module): def __init__(self): super().__init__() self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2) self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2) @video_to_image def forward(self, x): b, c, h, w = x.shape x = x.reshape(b * c, 1, h, w) low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2) low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2) high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2) high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2) coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1) return coeffs class InverseHaarWaveletTransform2D(nn.Module): def __init__(self): super().__init__() self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2) self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2) @video_to_image def forward(self, coeffs): low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1) b, c, height_half, width_half = 
low_low.shape height = height_half * 2 width = width_half * 2 low_low = F.conv_transpose2d( low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2 ) low_high = F.conv_transpose2d( low_high.reshape(b * c, 1, height_half, width_half), self.ad, stride=2 ) high_low = F.conv_transpose2d( high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2 ) high_high = F.conv_transpose2d( high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2 ) return (low_low + low_high + high_low + high_high).reshape(b, c, height, width) ================================================ FILE: opensora/models/causalvideovae/model/registry.py ================================================ class ModelRegistry: _models = {} @classmethod def register(cls, model_name): def decorator(model_class): cls._models[model_name] = model_class return model_class return decorator @classmethod def get_model(cls, model_name): return cls._models.get(model_name) ================================================ FILE: opensora/models/causalvideovae/model/trainer_videobase.py ================================================ from transformers import Trainer import torch.nn.functional as F from typing import Optional import os import torch from transformers.utils import WEIGHTS_NAME import json class VideoBaseTrainer(Trainer): def _save(self, output_dir: Optional[str] = None, state_dict=None): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) if state_dict is None: state_dict = self.model.state_dict() # get model config model_config = self.model.config.to_dict() # add more information model_config['model'] = self.model.__class__.__name__ with open(os.path.join(output_dir, "config.json"), "w") as file: json.dump(self.model.config.to_dict(), file) torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 
================================================ FILE: opensora/models/causalvideovae/model/utils/__init__.py ================================================ ================================================ FILE: opensora/models/causalvideovae/model/utils/distrib_utils.py ================================================ import torch import numpy as np class DiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) def sample(self): x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype) return x def kl(self, other=None): if self.deterministic: return torch.Tensor([0.]) else: if other is None: return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3]) def nll(self, sample, dims=[1,2,3]): if self.deterministic: return torch.Tensor([0.]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean ================================================ FILE: opensora/models/causalvideovae/model/utils/module_utils.py ================================================ import importlib Module = str MODULES_BASE = "opensora.models.causalvideovae.model.modules." 
def resolve_str_to_obj(str_val, append=True): if append: str_val = MODULES_BASE + str_val module_name, class_name = str_val.rsplit('.', 1) module = importlib.import_module(module_name) return getattr(module, class_name) def create_instance(module_class_str: str, **kwargs): module_name, class_name = module_class_str.rsplit('.', 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_(**kwargs) ================================================ FILE: opensora/models/causalvideovae/model/utils/scheduler_utils.py ================================================ import torch def cosine_scheduler(step, max_steps, value_base=1, value_end=0): step = torch.tensor(step) cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps)) value = value_end + (value_base - value_end) * cosine_value return value ================================================ FILE: opensora/models/causalvideovae/model/utils/video_utils.py ================================================ import torch import numpy as np def tensor_to_video(x): x = (x * 2 - 1).detach().cpu() x = torch.clamp(x, -1, 1) x = (x + 1) / 2 x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w x = (255 * x).astype(np.uint8) return x ================================================ FILE: opensora/models/causalvideovae/model/utils/wavelet_utils.py ================================================ import torch import torch.nn.functional as F import torch.nn as nn from ..modules import CausalConv3d from einops import rearrange class HaarWaveletTransform3D(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536 g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536 hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536 gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536 h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536 
g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536 hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536 gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536 h = h.view(1, 1, 2, 2, 2) g = g.view(1, 1, 2, 2, 2) hh = hh.view(1, 1, 2, 2, 2) gh = gh.view(1, 1, 2, 2, 2) h_v = h_v.view(1, 1, 2, 2, 2) g_v = g_v.view(1, 1, 2, 2, 2) hh_v = hh_v.view(1, 1, 2, 2, 2) gh_v = gh_v.view(1, 1, 2, 2, 2) self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) self.h_conv.conv.weight.data = h self.g_conv.conv.weight.data = g self.hh_conv.conv.weight.data = hh self.gh_conv.conv.weight.data = gh self.h_v_conv.conv.weight.data = h_v self.g_v_conv.conv.weight.data = g_v self.hh_v_conv.conv.weight.data = hh_v self.gh_v_conv.conv.weight.data = gh_v self.h_conv.requires_grad_(False) self.g_conv.requires_grad_(False) self.hh_conv.requires_grad_(False) self.gh_conv.requires_grad_(False) self.h_v_conv.requires_grad_(False) self.g_v_conv.requires_grad_(False) self.hh_v_conv.requires_grad_(False) self.gh_v_conv.requires_grad_(False) def forward(self, x): assert x.dim() == 5 b = x.shape[0] x = rearrange(x, "b c t h w -> (b c) 1 t h w") low_low_low = self.h_conv(x) low_low_low = rearrange(low_low_low, "(b c) 1 t h w -> b c t h w", b=b) low_low_high = self.g_conv(x) low_low_high = rearrange(low_low_high, "(b c) 1 t h w -> b c t h w", b=b) low_high_low = self.hh_conv(x) low_high_low = rearrange(low_high_low, "(b c) 1 t h w -> b c t h w", b=b) 
low_high_high = self.gh_conv(x) low_high_high = rearrange(low_high_high, "(b c) 1 t h w -> b c t h w", b=b) high_low_low = self.h_v_conv(x) high_low_low = rearrange(high_low_low, "(b c) 1 t h w -> b c t h w", b=b) high_low_high = self.g_v_conv(x) high_low_high = rearrange(high_low_high, "(b c) 1 t h w -> b c t h w", b=b) high_high_low = self.hh_v_conv(x) high_high_low = rearrange(high_high_low, "(b c) 1 t h w -> b c t h w", b=b) high_high_high = self.gh_v_conv(x) high_high_high = rearrange(high_high_high, "(b c) 1 t h w -> b c t h w", b=b) output = torch.cat( [ low_low_low, low_low_high, low_high_low, low_high_high, high_low_low, high_low_high, high_high_low, high_high_high, ], dim=1, ) return output class InverseHaarWaveletTransform3D(nn.Module): def __init__(self, enable_cached=False, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.register_buffer('h', torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('g', torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('hh', torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('gh', torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('h_v', torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('g_v', torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('hh_v', torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.register_buffer('gh_v', torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 ) self.enable_cached = enable_cached self.causal_cached = None def forward(self, coeffs): assert coeffs.dim() == 5 b = coeffs.shape[0] ( low_low_low, low_low_high, low_high_low, low_high_high, high_low_low, high_low_high, high_high_low, 
high_high_high, ) = coeffs.chunk(8, dim=1) low_low_low = rearrange(low_low_low, "b c t h w -> (b c) 1 t h w") low_low_high = rearrange(low_low_high, "b c t h w -> (b c) 1 t h w") low_high_low = rearrange(low_high_low, "b c t h w -> (b c) 1 t h w") low_high_high = rearrange(low_high_high, "b c t h w -> (b c) 1 t h w") high_low_low = rearrange(high_low_low, "b c t h w -> (b c) 1 t h w") high_low_high = rearrange(high_low_high, "b c t h w -> (b c) 1 t h w") high_high_low = rearrange(high_high_low, "b c t h w -> (b c) 1 t h w") high_high_high = rearrange(high_high_high, "b c t h w -> (b c) 1 t h w") low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2) low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2) low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2) low_high_high = F.conv_transpose3d(low_high_high, self.gh, stride=2) high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2) high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2) high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2) high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2) if self.enable_cached and self.causal_cached: reconstructed = ( low_low_low + low_low_high + low_high_low + low_high_high + high_low_low + high_low_high + high_high_low + high_high_high ) else: reconstructed = ( low_low_low[:, :, 1:] + low_low_high[:, :, 1:] + low_high_low[:, :, 1:] + low_high_high[:, :, 1:] + high_low_low[:, :, 1:] + high_low_high[:, :, 1:] + high_high_low[:, :, 1:] + high_high_high[:, :, 1:] ) self.causal_cached = True reconstructed = rearrange(reconstructed, "(b c) 1 t h w -> b c t h w", b=b) return reconstructed class HaarWaveletTransform2D(nn.Module): def __init__(self): super().__init__() self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2) self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 
1, 2, 2) / 2) self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2) def forward(self, x): b, c, h, w = x.shape x = x.reshape(b * c, 1, h, w) low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2) low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2) high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2) high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2) coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1) return coeffs class InverseHaarWaveletTransform2D(nn.Module): def __init__(self): super().__init__() self.register_buffer('aa', torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2) self.register_buffer('ad', torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('da', torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2) self.register_buffer('dd', torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2) def forward(self, coeffs): low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1) b, c, height_half, width_half = low_low.shape height = height_half * 2 width = width_half * 2 low_low = F.conv_transpose2d( low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2 ) low_high = F.conv_transpose2d( low_high.reshape(b * c, 1, height_half, width_half), self.ad, stride=2 ) high_low = F.conv_transpose2d( high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2 ) high_high = F.conv_transpose2d( high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2 ) return (low_low + low_high + high_low + high_high).reshape(b, c, height, width) ================================================ FILE: opensora/models/causalvideovae/model/vae/__init__.py ================================================ from .modeling_causalvae import CausalVAEModel from .modeling_wfvae import WFVAEModel from einops import rearrange from torch import nn ================================================ FILE: 
opensora/models/causalvideovae/model/vae/modeling_causalvae.py ================================================ try: import torch_npu from opensora.npu_config import npu_config except: torch_npu = None npu_config = None from ..modeling_videobase import VideoBaseAE from ..modules import Normalize from ..modules.ops import nonlinearity from typing import Tuple import torch.nn as nn from ..utils.module_utils import resolve_str_to_obj, Module from ..utils.distrib_utils import DiagonalGaussianDistribution from ..registry import ModelRegistry import torch from diffusers.configuration_utils import register_to_config from copy import deepcopy import os class Encoder(nn.Module): def __init__( self, z_channels: int, hidden_size: int, hidden_size_mult: Tuple[int] = (1, 2, 4, 4), attn_resolutions: Tuple[int] = (16,), conv_in: Module = "Conv2d", conv_out: Module = "CasualConv3d", attention: Module = "AttnBlock", resnet_blocks: Tuple[Module] = ( "ResnetBlock2D", "ResnetBlock2D", "ResnetBlock2D", "ResnetBlock3D", ), spatial_downsample: Tuple[Module] = ( "Downsample", "Downsample", "Downsample", "", ), temporal_downsample: Tuple[Module] = ("", "", "TimeDownsampleRes2x", ""), mid_resnet: Module = "ResnetBlock3D", dropout: float = 0.0, resolution: int = 256, num_res_blocks: int = 2, double_z: bool = True, norm_type: str = "groupnorm", ) -> None: super().__init__() assert len(resnet_blocks) == len(hidden_size_mult), print( hidden_size_mult, resnet_blocks ) # ---- Config ---- self.num_resolutions = len(hidden_size_mult) self.resolution = resolution self.num_res_blocks = num_res_blocks # ---- In ---- self.conv_in = resolve_str_to_obj(conv_in)( 3, hidden_size, kernel_size=3, stride=1, padding=1 ) # ---- Downsample ---- curr_res = resolution in_ch_mult = (1,) + tuple(hidden_size_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = hidden_size * in_ch_mult[i_level] block_out = 
hidden_size * hidden_size_mult[i_level] for i_block in range(self.num_res_blocks): block.append( resolve_str_to_obj(resnet_blocks[i_level])( in_channels=block_in, out_channels=block_out, dropout=dropout, norm_type=norm_type ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(resolve_str_to_obj(attention)(block_in)) down = nn.Module() down.block = block down.attn = attn if spatial_downsample[i_level]: down.downsample = resolve_str_to_obj(spatial_downsample[i_level])( block_in, block_in ) curr_res = curr_res // 2 if temporal_downsample[i_level]: down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])( block_in, block_in ) self.down.append(down) # ---- Mid ---- self.mid = nn.Module() self.mid.block_1 = resolve_str_to_obj(mid_resnet)( in_channels=block_in, out_channels=block_in, dropout=dropout, norm_type=norm_type ) self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) self.mid.block_2 = resolve_str_to_obj(mid_resnet)( in_channels=block_in, out_channels=block_in, dropout=dropout, norm_type=norm_type ) # ---- Out ---- self.norm_out = Normalize(block_in, norm_type=norm_type) self.conv_out = resolve_str_to_obj(conv_out)( block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1, ) def forward(self, x): h = self.conv_in(x) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](h) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) if hasattr(self.down[i_level], "downsample"): h = self.down[i_level].downsample(h) if hasattr(self.down[i_level], "time_downsample"): h = self.down[i_level].time_downsample(h) h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) if npu_config is None: h = self.norm_out(h) else: h = npu_config.run_group_norm(self.norm_out, h) h = nonlinearity(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, z_channels: int, hidden_size: int, 
hidden_size_mult: Tuple[int] = (1, 2, 4, 4), attn_resolutions: Tuple[int] = (16,), conv_in: Module = "Conv2d", conv_out: Module = "CasualConv3d", attention: Module = "AttnBlock", resnet_blocks: Tuple[Module] = ( "ResnetBlock3D", "ResnetBlock3D", "ResnetBlock3D", "ResnetBlock3D", ), spatial_upsample: Tuple[Module] = ( "", "SpatialUpsample2x", "SpatialUpsample2x", "SpatialUpsample2x", ), temporal_upsample: Tuple[Module] = ("", "", "", "TimeUpsampleRes2x"), mid_resnet: Module = "ResnetBlock3D", dropout: float = 0.0, resolution: int = 256, num_res_blocks: int = 2, norm_type: str = "groupnorm", ): super().__init__() # ---- Config ---- self.num_resolutions = len(hidden_size_mult) self.resolution = resolution self.num_res_blocks = num_res_blocks # ---- In ---- block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.conv_in = resolve_str_to_obj(conv_in)( z_channels, block_in, kernel_size=3, padding=1 ) # ---- Mid ---- self.mid = nn.Module() self.mid.block_1 = resolve_str_to_obj(mid_resnet)( in_channels=block_in, out_channels=block_in, dropout=dropout, norm_type=norm_type ) self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, norm_type=norm_type) self.mid.block_2 = resolve_str_to_obj(mid_resnet)( in_channels=block_in, out_channels=block_in, dropout=dropout, norm_type=norm_type ) # ---- Upsample ---- self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = hidden_size * hidden_size_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( resolve_str_to_obj(resnet_blocks[i_level])( in_channels=block_in, out_channels=block_out, dropout=dropout, norm_type=norm_type ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(resolve_str_to_obj(attention)(block_in, norm_type=norm_type)) up = nn.Module() up.block = block up.attn = attn if spatial_upsample[i_level]: up.upsample = 
resolve_str_to_obj(spatial_upsample[i_level])( block_in, block_in ) curr_res = curr_res * 2 if temporal_upsample[i_level]: up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])( block_in, block_in ) self.up.insert(0, up) # ---- Out ---- self.norm_out = Normalize(block_in, norm_type=norm_type) self.conv_out = resolve_str_to_obj(conv_out)( block_in, 3, kernel_size=3, padding=1 ) def forward(self, z): h = self.conv_in(z) h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if hasattr(self.up[i_level], "upsample"): h = self.up[i_level].upsample(h) if hasattr(self.up[i_level], "time_upsample"): h = self.up[i_level].time_upsample(h) if npu_config is None: h = self.norm_out(h) else: h = npu_config.run_group_norm(self.norm_out, h) h = nonlinearity(h) h = self.conv_out(h) return h @ModelRegistry.register("CausalVAE") class CausalVAEModel(VideoBaseAE): @register_to_config def __init__( self, hidden_size: int = 128, z_channels: int = 4, hidden_size_mult: Tuple[int] = (1, 2, 4, 4), attn_resolutions: Tuple[int] = [], dropout: float = 0.0, resolution: int = 256, double_z: bool = True, embed_dim: int = 4, num_res_blocks: int = 2, q_conv: str = "CausalConv3d", encoder_conv_in: Module = "CausalConv3d", encoder_conv_out: Module = "CausalConv3d", encoder_attention: Module = "AttnBlock3D", encoder_resnet_blocks: Tuple[Module] = ( "ResnetBlock3D", "ResnetBlock3D", "ResnetBlock3D", "ResnetBlock3D", ), encoder_spatial_downsample: Tuple[Module] = ( "SpatialDownsample2x", "SpatialDownsample2x", "SpatialDownsample2x", "", ), encoder_temporal_downsample: Tuple[Module] = ( "", "TimeDownsample2x", "TimeDownsample2x", "", ), encoder_mid_resnet: Module = "ResnetBlock3D", decoder_conv_in: Module = "CausalConv3d", decoder_conv_out: Module = "CausalConv3d", 
decoder_attention: Module = "AttnBlock3D", decoder_resnet_blocks: Tuple[Module] = ( "ResnetBlock3D", "ResnetBlock3D", "ResnetBlock3D", "ResnetBlock3D", ), decoder_spatial_upsample: Tuple[Module] = ( "", "SpatialUpsample2x", "SpatialUpsample2x", "SpatialUpsample2x", ), decoder_temporal_upsample: Tuple[Module] = ( "", "", "TimeUpsample2x", "TimeUpsample2x", ), decoder_mid_resnet: Module = "ResnetBlock3D", use_quant_layer: bool = True, norm_type: str = "groupnorm", ) -> None: super().__init__() self.tile_sample_min_size = 512000 self.tile_sample_min_size_t = 33 self.tile_sample_min_size_dec = 512 self.tile_sample_min_size_t_dec = 17 self.tile_latent_min_size = int(self.tile_sample_min_size_dec / (2 ** (len(hidden_size_mult) - 1))) self.tile_latent_min_size_t = int((self.tile_sample_min_size_t_dec-1) / 4) + 1 self.tile_overlap_t = 2 self.tile_overlap_factor = 0.125 self.use_tiling = False self.use_quant_layer = use_quant_layer self.encoder = Encoder( z_channels=z_channels, hidden_size=hidden_size, hidden_size_mult=hidden_size_mult, attn_resolutions=attn_resolutions, conv_in=encoder_conv_in, conv_out=encoder_conv_out, attention=encoder_attention, resnet_blocks=encoder_resnet_blocks, spatial_downsample=encoder_spatial_downsample, temporal_downsample=encoder_temporal_downsample, mid_resnet=encoder_mid_resnet, dropout=dropout, resolution=resolution, num_res_blocks=num_res_blocks, double_z=double_z, norm_type=norm_type ) self.decoder = Decoder( z_channels=z_channels, hidden_size=hidden_size, hidden_size_mult=hidden_size_mult, attn_resolutions=attn_resolutions, conv_in=decoder_conv_in, conv_out=decoder_conv_out, attention=decoder_attention, resnet_blocks=decoder_resnet_blocks, spatial_upsample=decoder_spatial_upsample, temporal_upsample=decoder_temporal_upsample, mid_resnet=decoder_mid_resnet, dropout=dropout, resolution=resolution, num_res_blocks=num_res_blocks, norm_type=norm_type ) if self.use_quant_layer: quant_conv_cls = resolve_str_to_obj(q_conv) self.quant_conv = 
quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) def get_encoder(self): if self.use_quant_layer: return [self.quant_conv, self.encoder] return [self.encoder] def get_decoder(self): if self.use_quant_layer: return [self.post_quant_conv, self.decoder] return [self.decoder] def encode(self, x): if self.use_tiling and ( x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size or x.shape[-3] > self.tile_sample_min_size_t ): # import ipdb;ipdb.set_trace() return self.tiled_encode(x) h = self.encoder(x) if self.use_quant_layer: h = self.quant_conv(h) posterior = DiagonalGaussianDistribution(h) return posterior def decode(self, z): if self.use_tiling and ( z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size or z.shape[-3] > self.tile_latent_min_size_t ): return self.tiled_decode(z) if self.use_quant_layer: z = self.post_quant_conv(z) dec = self.decoder(z) return dec def forward(self, input, sample_posterior=True): posterior = self.encode(input) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) return dec, posterior def on_train_start(self): self.ema = deepcopy(self) if self.save_ema == True else None def get_last_layer(self): if hasattr(self.decoder.conv_out, "conv"): return self.decoder.conv_out.conv.weight else: return self.decoder.conv_out.weight def blend_v( self, a: torch.Tensor, b: torch.Tensor, blend_extent: int ) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for y in range(blend_extent): b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( 1 - y / blend_extent ) + b[:, :, :, y, :] * (y / blend_extent) return b def blend_h( self, a: torch.Tensor, b: torch.Tensor, blend_extent: int ) -> torch.Tensor: blend_extent = min(a.shape[4], b.shape[4], blend_extent) for x in range(blend_extent): b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( 1 - x / blend_extent ) + b[:, :, :, 
:, x] * (x / blend_extent) return b def tiled_encode(self, x): t = x.shape[2] t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)] # print('tiled_encode', t_chunk_idx) if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: t_chunk_start_end = [[0, t]] else: t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1+(self.tile_overlap_t-1)*4] for i in range(len(t_chunk_idx)-1)] if t_chunk_start_end[-1][-1] > t: t_chunk_start_end[-1][-1] = t elif t_chunk_start_end[-1][-1] < t: last_start_end = [t_chunk_idx[-1], t] t_chunk_start_end.append(last_start_end) moments = [] # print('tiled_encode t_chunk_start_end', t_chunk_start_end) for idx, (start, end) in enumerate(t_chunk_start_end): chunk_x = x[:, :, start: end] if idx != 0: moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1+(self.tile_overlap_t-1):] else: moment = self.tiled_encode2d(chunk_x, return_moments=True) moments.append(moment) moments = torch.cat(moments, dim=2) posterior = DiagonalGaussianDistribution(moments) return posterior def tiled_decode(self, x): t = x.shape[2] t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)] # print('tiled_decode', t_chunk_idx) if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: t_chunk_start_end = [[0, t]] else: t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1+(self.tile_overlap_t-1)] for i in range(len(t_chunk_idx)-1)] if t_chunk_start_end[-1][-1] > t: t_chunk_start_end[-1][-1] = t elif t_chunk_start_end[-1][-1] < t: last_start_end = [t_chunk_idx[-1], t] t_chunk_start_end.append(last_start_end) dec_ = [] # print('tiled_decode t_chunk_start_end', t_chunk_start_end) for idx, (start, end) in enumerate(t_chunk_start_end): # import ipdb;ipdb.set_trace() chunk_x = x[:, :, start: end] if idx != 0: dec = self.tiled_decode2d(chunk_x)[:, :, 1+(self.tile_overlap_t-1)*4:] else: dec = self.tiled_decode2d(chunk_x) # print(chunk_x.shape, dec.shape) dec_.append(dec) dec_ = torch.cat(dec_, dim=2) return dec_ def tiled_encode2d(self, x, 
return_moments=False): overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent # print('overlap_size, blend_extent, row_limit', overlap_size, blend_extent, row_limit) # Split the image into 512x512 tiles and encode them separately. rows = [] for i in range(0, x.shape[3], overlap_size): row = [] for j in range(0, x.shape[4], overlap_size): # print(i, j) tile = x[ :, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size, ] tile = self.encoder(tile) if self.use_quant_layer: tile = self.quant_conv(tile) row.append(tile) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=4)) moments = torch.cat(result_rows, dim=3) posterior = DiagonalGaussianDistribution(moments) if return_moments: return moments return posterior def tiled_decode2d(self, z): overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) row_limit = self.tile_sample_min_size - blend_extent # Split z into overlapping 64x64 tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. 
# print('tiled_decode2d', list(range(0, z.shape[3], overlap_size)), list(range(0, z.shape[4], overlap_size))) # import ipdb;ipdb.set_trace() rows = [] for i in range(0, z.shape[3], overlap_size): row = [] for j in range(0, z.shape[4], overlap_size): tile = z[ :, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size, ] if self.use_quant_layer: tile = self.post_quant_conv(tile) decoded = self.decoder(tile) row.append(decoded) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=4)) dec = torch.cat(result_rows, dim=3) return dec def enable_tiling(self, use_tiling: bool = True): self.use_tiling = use_tiling def disable_tiling(self): self.enable_tiling(False) def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location="cpu") print("init from " + path) if ( "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0 ): print("Load from ema model!") sd = sd["ema_state_dict"] sd = {key.replace("module.", ""): value for key, value in sd.items()} elif "state_dict" in sd: print("Load from normal model!") if "gen_model" in sd["state_dict"]: sd = sd["state_dict"]["gen_model"] else: sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) ================================================ FILE: opensora/models/causalvideovae/model/vae/modeling_wfvae.py ================================================ try: import torch_npu 
from opensora.npu_config import npu_config except: torch_npu = None npu_config = None
from ..modeling_videobase import VideoBaseAE
from diffusers.configuration_utils import register_to_config
import torch
import torch.nn as nn
from ..modules import (
    ResnetBlock2D,
    ResnetBlock3D,
    Conv2d,
    HaarWaveletTransform3D,
    InverseHaarWaveletTransform3D,
    CausalConv3d,
    Normalize,
    AttnBlock3DFix,
    nonlinearity,
)
import torch.nn as nn
from ..utils.distrib_utils import DiagonalGaussianDistribution
import torch
from copy import deepcopy
import os
from ..registry import ModelRegistry
from einops import rearrange
from collections import deque
from ..utils.module_utils import resolve_str_to_obj, Module
from typing import List


class Encoder(VideoBaseAE):
    """Wavelet-driven encoder.

    The input is first passed through a 3D Haar transform; the main branch
    (``down1`` -> ``down2`` -> ``mid``) consumes all resulting channels, while
    two side branches re-transform the first three channels and inject them
    (via ``connect_l1`` / ``connect_l2``) by channel concatenation.
    """

    @register_to_config
    def __init__(
        self,
        latent_dim: int = 8,
        base_channels: int = 128,
        num_resblocks: int = 2,
        energy_flow_hidden_size: int = 64,
        dropout: float = 0.0,
        attention_type: str = "AttnBlock3DFix",
        use_attention: bool = True,
        norm_type: str = "groupnorm",
        l1_dowmsample_block: str = "Downsample",
        l1_downsample_wavelet: str = "HaarWaveletTransform2D",
        l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample",
        l2_downsample_wavelet: str = "HaarWaveletTransform3D",
    ) -> None:
        super().__init__()
        # Level 1: 24 input channels (presumably 8 sub-bands x 3 colour channels
        # from the input wavelet transform — confirm against callers).
        self.down1 = nn.Sequential(
            Conv2d(24, base_channels, kernel_size=3, stride=1, padding=1),
            *[
                ResnetBlock2D(
                    in_channels=base_channels,
                    out_channels=base_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for _ in range(num_resblocks)
            ],
            resolve_str_to_obj(l1_dowmsample_block)(in_channels=base_channels, out_channels=base_channels),
        )
        # Level 2: input is the level-1 output concatenated with the l1 side branch.
        self.down2 = nn.Sequential(
            Conv2d(
                base_channels + energy_flow_hidden_size,
                base_channels * 2,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            *[
                ResnetBlock3D(
                    in_channels=base_channels * 2,
                    out_channels=base_channels * 2,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for _ in range(num_resblocks)
            ],
            resolve_str_to_obj(l2_dowmsample_block)(base_channels * 2, base_channels * 2),
        )
        # Connection
        # The l1 side-branch channel count depends on the chosen downsample block:
        # 12 for "Downsample", 24 otherwise.
        if l1_dowmsample_block == "Downsample": # Bad code. For temporal usage.
            l1_channels = 12
        else:
            l1_channels = 24

        self.connect_l1 = Conv2d(
            l1_channels, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1
        )
        self.connect_l2 = Conv2d(
            24, energy_flow_hidden_size, kernel_size=3, stride=1, padding=1
        )
        # Mid
        # First mid block takes the level-2 output concatenated with the l2 side branch.
        mid_layers = [
            ResnetBlock3D(
                in_channels=base_channels * 2 + energy_flow_hidden_size,
                out_channels=base_channels * 4,
                dropout=dropout,
                norm_type=norm_type,
            ),
            ResnetBlock3D(
                in_channels=base_channels * 4,
                out_channels=base_channels * 4,
                dropout=dropout,
                norm_type=norm_type,
            ),
        ]
        if use_attention:
            # Attention, when enabled, sits between the two mid resnet blocks.
            mid_layers.insert(
                1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type)
            )
        self.mid = nn.Sequential(*mid_layers)

        self.norm_out = Normalize(base_channels * 4, norm_type=norm_type)
        # Output carries 2 * latent_dim channels.
        self.conv_out = CausalConv3d(
            base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1
        )

        self.wavelet_transform_in = HaarWaveletTransform3D()
        self.wavelet_transform_l1 = resolve_str_to_obj(l1_downsample_wavelet)()
        self.wavelet_transform_l2 = resolve_str_to_obj(l2_downsample_wavelet)()

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(h, (l1_coeffs, l2_coeffs))`` — the ``conv_out`` output plus the
            wavelet coefficients computed for the two side branches.
        """
        coeffs = self.wavelet_transform_in(x)

        # Side branch 1: re-transform the first three channels of the input coefficients.
        l1_coeffs = coeffs[:, :3]
        l1_coeffs = self.wavelet_transform_l1(l1_coeffs)
        l1 = self.connect_l1(l1_coeffs)
        # Side branch 2: same again on the first three channels of the l1 coefficients.
        l2_coeffs = self.wavelet_transform_l2(l1_coeffs[:, :3])
        l2 = self.connect_l2(l2_coeffs)

        # Main branch, with the side branches concatenated on the channel axis.
        h = self.down1(coeffs)
        h = torch.concat([h, l1], dim=1)
        h = self.down2(h)
        h = torch.concat([h, l2], dim=1)
        h = self.mid(h)

        if npu_config is None:
            h = self.norm_out(h)
        else:
            h = npu_config.run_group_norm(self.norm_out, h)

        h = nonlinearity(h)
        h = self.conv_out(h)
        return h, (l1_coeffs, l2_coeffs)


class Decoder(VideoBaseAE):
    """Wavelet-driven decoder.

    At the mid and ``up2`` stages the last ``energy_flow_hidden_size`` channels
    are split off, mapped to wavelet coefficients (``connect_l2`` /
    ``connect_l1``), inverse-transformed and added onto the first three
    channels of the next level's coefficients.
    """

    @register_to_config
    def __init__(
        self,
        latent_dim: int = 8,
        base_channels: int = 128,
        num_resblocks: int = 2,
        dropout: float = 0.0,
        energy_flow_hidden_size: int = 128,
        attention_type: str = "AttnBlock3DFix",
        use_attention: bool = True,
        norm_type: str = "groupnorm",
        t_interpolation: str = "nearest",
        connect_res_layer_num: int = 1,
        l1_upsample_block: str = "Upsample",
        l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D",
        l2_upsample_block: str = "Spatial2xTime2x3DUpsample",
        l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D",
    ) -> None:
        super().__init__()
        # Kept so forward() can split the side-branch channels off the main tensor.
        self.energy_flow_hidden_size = energy_flow_hidden_size

        self.conv_in = CausalConv3d(
            latent_dim, base_channels * 4, kernel_size=3, stride=1, padding=1
        )

        # The second mid block emits extra channels for the l2 side branch.
        mid_layers = [
            ResnetBlock3D(
                in_channels=base_channels * 4,
                out_channels=base_channels * 4,
                dropout=dropout,
                norm_type=norm_type,
            ),
            ResnetBlock3D(
                in_channels=base_channels * 4,
                out_channels=base_channels * 4 + energy_flow_hidden_size,
                dropout=dropout,
                norm_type=norm_type,
            ),
        ]
        if use_attention:
            mid_layers.insert(
                1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type)
            )
        self.mid = nn.Sequential(*mid_layers)

        # The trailing block of up2 emits extra channels for the l1 side branch.
        self.up2 = nn.Sequential(
            *[
                ResnetBlock3D(
                    in_channels=base_channels * 4,
                    out_channels=base_channels * 4,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for _ in range(num_resblocks)
            ],
            resolve_str_to_obj(l2_upsample_block)(
                base_channels * 4, base_channels * 4, t_interpolation=t_interpolation
            ),
            ResnetBlock3D(
                in_channels=base_channels * 4,
                out_channels=base_channels * 4 + energy_flow_hidden_size,
                dropout=dropout,
                norm_type=norm_type,
            ),
        )
        # Only the first resnet block changes the channel count (4x -> 2x base).
        self.up1 = nn.Sequential(
            *[
                ResnetBlock3D(
                    in_channels=base_channels * (4 if i == 0 else 2),
                    out_channels=base_channels * 2,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for i in range(num_resblocks)
            ],
            resolve_str_to_obj(l1_upsample_block)(in_channels=base_channels * 2, out_channels=base_channels * 2),
            ResnetBlock3D(
                in_channels=base_channels * 2,
                out_channels=base_channels * 2,
                dropout=dropout,
                norm_type=norm_type,
            ),
        )
        self.layer = nn.Sequential(
            *[
                ResnetBlock3D(
                    in_channels=base_channels * (2 if i == 0 else 1),
                    out_channels=base_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for i in range(2)
            ],
        )
        # Connection
        # The l1 coefficient channel count depends on the chosen upsample block:
        # 12 for "Upsample", 24 otherwise.
        if l1_upsample_block == "Upsample": # Bad code. For temporal usage.
            l1_channels = 12
        else:
            l1_channels = 24
        self.connect_l1 = nn.Sequential(
            *[
                ResnetBlock3D(
                    in_channels=energy_flow_hidden_size,
                    out_channels=energy_flow_hidden_size,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for _ in range(connect_res_layer_num)
            ],
            Conv2d(energy_flow_hidden_size, l1_channels, kernel_size=3, stride=1, padding=1),
        )
        self.connect_l2 = nn.Sequential(
            *[
                ResnetBlock3D(
                    in_channels=energy_flow_hidden_size,
                    out_channels=energy_flow_hidden_size,
                    dropout=dropout,
                    norm_type=norm_type,
                )
                for _ in range(connect_res_layer_num)
            ],
            Conv2d(energy_flow_hidden_size, 24, kernel_size=3, stride=1, padding=1),
        )
        # Out
        self.norm_out = Normalize(base_channels, norm_type=norm_type)
        self.conv_out = Conv2d(base_channels, 24, kernel_size=3, stride=1, padding=1)

        self.inverse_wavelet_transform_out = InverseHaarWaveletTransform3D()
        self.inverse_wavelet_transform_l1 = resolve_str_to_obj(l1_upsample_wavelet)()
        self.inverse_wavelet_transform_l2 = resolve_str_to_obj(l2_upsample_wavelet)()

    def forward(self, z):
        """Decode latent ``z``.

        Returns:
            ``(dec, (l1_coeffs, l2_coeffs))`` — the inverse-wavelet output plus
            the coefficients predicted by the two side branches.
        """
        h = self.conv_in(z)
        h = self.mid(h)

        # Last energy_flow_hidden_size channels feed the l2 side branch ...
        l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size :])
        l2 = self.inverse_wavelet_transform_l2(l2_coeffs)

        # ... and the remaining channels continue through the main branch.
        h = self.up2(h[:, : -self.energy_flow_hidden_size])

        l1_coeffs = h[:, -self.energy_flow_hidden_size :]
        l1_coeffs = self.connect_l1(l1_coeffs)
        # In-place: add the reconstructed l2 signal onto the first three l1 channels.
        l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2
        l1 = self.inverse_wavelet_transform_l1(l1_coeffs)

        h = self.up1(h[:, : -self.energy_flow_hidden_size])

        h = self.layer(h)
        if npu_config is None:
            h = self.norm_out(h)
        else:
            h = npu_config.run_group_norm(self.norm_out, h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        # In-place: add the reconstructed l1 signal onto the first three output channels.
        h[:, :3] = h[:, :3] + l1

        dec = self.inverse_wavelet_transform_out(h)
        return dec, (l1_coeffs, l2_coeffs)

@ModelRegistry.register("WFVAE") class WFVAEModel(VideoBaseAE): @register_to_config def __init__( self, latent_dim: int = 8, base_channels: int = 128, encoder_num_resblocks: int = 2, encoder_energy_flow_hidden_size: int = 64, decoder_num_resblocks: int = 2, decoder_energy_flow_hidden_size: int
= 128, attention_type: str = "AttnBlock3DFix", use_attention: bool = True, dropout: float = 0.0, norm_type: str = "groupnorm", t_interpolation: str = "nearest", connect_res_layer_num: int = 1, scale: List[float] = [0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215], shift: List[float] = [0, 0, 0, 0, 0, 0, 0, 0], # Module config l1_dowmsample_block: str = "Downsample", l1_downsample_wavelet: str = "HaarWaveletTransform2D", l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample", l2_downsample_wavelet: str = "HaarWaveletTransform3D", l1_upsample_block: str = "Upsample", l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D", l2_upsample_block: str = "Spatial2xTime2x3DUpsample", l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D", ) -> None: super().__init__() self.use_tiling = False # Hardcode for now self.t_chunk_enc = 8 self.t_chunk_dec = 2 self.t_upsample_times = 4 // 2 self.use_quant_layer = False self.encoder = Encoder( latent_dim=latent_dim, base_channels=base_channels, num_resblocks=encoder_num_resblocks, energy_flow_hidden_size=encoder_energy_flow_hidden_size, dropout=dropout, use_attention=use_attention, norm_type=norm_type, l1_dowmsample_block=l1_dowmsample_block, l1_downsample_wavelet=l1_downsample_wavelet, l2_dowmsample_block=l2_dowmsample_block, l2_downsample_wavelet=l2_downsample_wavelet, attention_type=attention_type ) self.decoder = Decoder( latent_dim=latent_dim, base_channels=base_channels, num_resblocks=decoder_num_resblocks, energy_flow_hidden_size=decoder_energy_flow_hidden_size, dropout=dropout, use_attention=use_attention, norm_type=norm_type, t_interpolation=t_interpolation, connect_res_layer_num=connect_res_layer_num, l1_upsample_block=l1_upsample_block, l1_upsample_wavelet=l1_upsample_wavelet, l2_upsample_block=l2_upsample_block, l2_upsample_wavelet=l2_upsample_wavelet, attention_type=attention_type ) # Set cache offset for trilinear lossless upsample. 
self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid], 1) self._set_cache_offset([self.decoder.up2[-2:], self.decoder.up1, self.decoder.connect_l1, self.decoder.layer], self.t_upsample_times) def get_encoder(self): if self.use_quant_layer: return [self.quant_conv, self.encoder] return [self.encoder] def get_decoder(self): if self.use_quant_layer: return [self.post_quant_conv, self.decoder] return [self.decoder] def _empty_causal_cached(self, parent): for name, module in parent.named_modules(): if hasattr(module, 'causal_cached'): module.causal_cached = deque() def _set_causal_cached(self, enable_cached=True): for name, module in self.named_modules(): if hasattr(module, 'enable_cached'): module.enable_cached = enable_cached def _set_cache_offset(self, modules, cache_offset=0): for module in modules: for submodule in module.modules(): if hasattr(submodule, 'cache_offset'): submodule.cache_offset = cache_offset def _set_first_chunk(self, is_first_chunk=True): for module in self.modules(): if hasattr(module, 'is_first_chunk'): module.is_first_chunk = is_first_chunk def build_chunk_start_end(self, t, decoder_mode=False): start_end = [[0, 1]] start = 1 end = start while True: if start >= t: break end = min(t, end + (self.t_chunk_dec if decoder_mode else self.t_chunk_enc) ) start_end.append([start, end]) start = end return start_end def encode(self, x): self._empty_causal_cached(self.encoder) self._set_first_chunk(True) if self.use_tiling: h = self.tile_encode(x) l1, l2 = None, None else: h, (l1, l2) = self.encoder(x) if self.use_quant_layer: h = self.quant_conv(h) posterior = DiagonalGaussianDistribution(h) return posterior def tile_encode(self, x): b, c, t, h, w = x.shape start_end = self.build_chunk_start_end(t) result = [] for idx, (start, end) in enumerate(start_end): self._set_first_chunk(idx == 0) chunk = x[:, :, start:end, :, :] chunk = self.encoder(chunk)[0] if self.use_quant_layer: chunk = 
self.quant_conv(chunk) result.append(chunk) return torch.cat(result, dim=2) def decode(self, z): self._empty_causal_cached(self.decoder) self._set_first_chunk(True) if self.use_tiling: dec = self.tile_decode(z) l1, l2 = None, None else: if self.use_quant_layer: z = self.post_quant_conv(z) dec, (l1, l2) = self.decoder(z) return dec def tile_decode(self, x): b, c, t, h, w = x.shape start_end = self.build_chunk_start_end(t, decoder_mode=True) result = [] for idx, (start, end) in enumerate(start_end): self._set_first_chunk(idx==0) if end + 1 < t: chunk = x[:, :, start:end+1, :, :] else: chunk = x[:, :, start:end, :, :] if self.use_quant_layer: chunk = self.post_quant_conv(chunk) chunk = self.decoder(chunk)[0] if end + 1 < t: chunk = chunk[:, :, :-4] result.append(chunk.clone()) else: result.append(chunk.clone()) return torch.cat(result, dim=2) def forward(self, input, sample_posterior=True): posterior = self.encode(input) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) return dec, posterior def get_last_layer(self): if hasattr(self.decoder.conv_out, "conv"): return self.decoder.conv_out.conv.weight else: return self.decoder.conv_out.weight def enable_tiling(self, use_tiling: bool = True): self.use_tiling = use_tiling self._set_causal_cached(use_tiling) def disable_tiling(self): self.enable_tiling(False) def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location="cpu") print("init from " + path) if ( "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0 ): print("Load from ema model!") sd = sd["ema_state_dict"] sd = {key.replace("module.", ""): value for key, value in sd.items()} elif "state_dict" in sd: print("Load from normal model!") if "gen_model" in sd["state_dict"]: sd = sd["state_dict"]["gen_model"] else: sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from 
state_dict.".format(k)) del sd[k] missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) print(missing_keys, unexpected_keys) ================================================ FILE: opensora/models/causalvideovae/sample/rec_video_vae.py ================================================ import argparse from tqdm import tqdm import torch import sys from torch.utils.data import DataLoader, Subset import os from accelerate import Accelerator sys.path.append(".") from opensora.models.causalvideovae.model import * from opensora.models.causalvideovae.dataset.video_dataset import ValidVideoDataset from opensora.models.causalvideovae.utils.video_utils import custom_to_video @torch.no_grad() def main(args: argparse.Namespace): accelerator = Accelerator() device = accelerator.device real_video_dir = args.real_video_dir generated_video_dir = args.generated_video_dir sample_rate = args.sample_rate resolution = args.resolution crop_size = args.crop_size num_frames = args.num_frames sample_rate = args.sample_rate device = args.device sample_fps = args.sample_fps batch_size = args.batch_size num_workers = args.num_workers subset_size = args.subset_size if not os.path.exists(args.generated_video_dir): os.makedirs(args.generated_video_dir, exist_ok=True) data_type = torch.bfloat16 # ---- Load Model ---- device = args.device model_cls = ModelRegistry.get_model(args.model_name) vae = model_cls.from_pretrained(args.from_pretrained) vae = vae.to(device).to(data_type) if args.enable_tiling: vae.enable_tiling() vae.tile_overlap_factor = args.tile_overlap_factor # ---- Prepare Dataset ---- dataset = ValidVideoDataset( real_video_dir=real_video_dir, num_frames=num_frames, sample_rate=sample_rate, crop_size=crop_size, resolution=resolution, ) if subset_size: indices = range(subset_size) dataset = Subset(dataset, indices=indices) dataloader = DataLoader( dataset, batch_size=batch_size, pin_memory=False, num_workers=num_workers ) dataloader = accelerator.prepare(dataloader) # 
---- Inference ---- for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process): x, file_names = batch['video'], batch['file_name'] x = x.to(device=device, dtype=data_type) # b c t h w x = x * 2 - 1 encode_result = vae.encode(x) if isinstance(encode_result, tuple): encode_result = encode_result[0] latents = encode_result.sample().to(data_type) video_recon = vae.decode(latents) if isinstance(video_recon, tuple): video_recon = video_recon[0] for idx, video in enumerate(video_recon): output_path = os.path.join(generated_video_dir, file_names[idx]) if args.output_origin: os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True) origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx]) custom_to_video( x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path ) custom_to_video( video, fps=sample_fps / sample_rate, output_file=output_path ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--real_video_dir", type=str, default="") parser.add_argument("--generated_video_dir", type=str, default="") parser.add_argument("--from_pretrained", type=str, default="") parser.add_argument("--sample_fps", type=int, default=30) parser.add_argument("--resolution", type=int, default=336) parser.add_argument("--crop_size", type=int, default=None) parser.add_argument("--num_frames", type=int, default=17) parser.add_argument("--sample_rate", type=int, default=1) parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--subset_size", type=int, default=None) parser.add_argument("--tile_overlap_factor", type=float, default=0.25) parser.add_argument('--enable_tiling', action='store_true') parser.add_argument('--output_origin', action='store_true') parser.add_argument("--model_name", type=str, default=None, help="") parser.add_argument("--device", type=str, default="cuda") args = parser.parse_args() main(args) 
================================================ FILE: opensora/models/causalvideovae/utils/__init__.py ================================================ ================================================ FILE: opensora/models/causalvideovae/utils/dataset_utils.py ================================================ import math from einops import rearrange import decord from torch.nn import functional as F import torch IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) class DecordInit(object): """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" def __init__(self, num_threads=1): self.num_threads = num_threads self.ctx = decord.cpu(0) def __call__(self, filename): """Perform the Decord initialization. Args: results (dict): The resulting dict to be modified and passed to the next transform in pipeline. """ reader = decord.VideoReader(filename, ctx=self.ctx, num_threads=self.num_threads) return reader def __repr__(self): repr_str = (f'{self.__class__.__name__}(' f'sr={self.sr},' f'num_threads={self.num_threads})') return repr_str def pad_to_multiple(number, ds_stride): remainder = number % ds_stride if remainder == 0: return number else: padding = ds_stride - remainder return number + padding ================================================ FILE: opensora/models/causalvideovae/utils/downloader.py ================================================ import gdown import os opensora_cache_home = os.path.expanduser( os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) ) def gdown_download(id, fname, cache_dir=None): cache_dir = opensora_cache_home if not cache_dir else cache_dir os.makedirs(cache_dir, exist_ok=True) destination = os.path.join(cache_dir, fname) if os.path.exists(destination): return destination gdown.download(id=id, output=destination, quiet=False) return destination 
================================================ FILE: opensora/models/causalvideovae/utils/video_utils.py ================================================ import torch import numpy as np import numpy.typing as npt import cv2 from decord import VideoReader, cpu def array_to_video( image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4" ) -> None: """b h w c""" height, width, channels = image_array[0].shape fourcc = cv2.VideoWriter_fourcc(*"mp4v") video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) for image in image_array: image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) video_writer.write(image_rgb) video_writer.release() def custom_to_video( x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4" ) -> None: x = x.detach().cpu() x = torch.clamp(x, -1, 1) x = (x + 1) / 2 x = x.permute(1, 2, 3, 0).float().numpy() x = (255 * x).astype(np.uint8) array_to_video(x, fps=fps, output_file=output_file) return def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8) total_frames = len(decord_vr) sample_frames_len = sample_rate * num_frames if total_frames > sample_frames_len: s = 0 e = s + sample_frames_len num_frames = num_frames else: s = 0 e = total_frames num_frames = int(total_frames / sample_frames_len * num_frames) print( f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}", video_path, total_frames, ) frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) video_data = decord_vr.get_batch(frame_id_list).asnumpy() video_data = torch.from_numpy(video_data) video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) return video_data def tensor_to_video(x): """[0-1] tensor to video""" x = (x * 2 - 1).detach().cpu() x = torch.clamp(x, -1, 1) x = (x + 1) / 2 x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w x = (255 * x).astype(np.uint8) return x 
================================================ FILE: opensora/models/diffusion/__init__.py ================================================ from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models from .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models Diffusion_models = {} Diffusion_models.update(OpenSora_v1_3_models) Diffusion_models.update(OpenSoraInpaint_v1_3_models) from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models_class from .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models_class Diffusion_models_class = {} Diffusion_models_class.update(OpenSora_v1_3_models_class) Diffusion_models_class.update(OpenSoraInpaint_v1_3_models_class) ================================================ FILE: opensora/models/diffusion/common.py ================================================ import torch from einops import rearrange, repeat from typing import Any, Dict, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from diffusers.models.attention_processor import Attention as Attention_ try: import torch_npu from opensora.npu_config import npu_config, set_run_dtype from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info as xccl_info from opensora.acceleration.communications import all_to_all_SBH except: torch_npu = None npu_config = None set_run_dtype = None from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info as xccl_info from opensora.utils.communications import all_to_all_SBH class PatchEmbed2D(nn.Module): """2D Image to Patch Embedding but with video""" def __init__( self, patch_size=16, in_channels=3, embed_dim=768, bias=True, ): super().__init__() self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias ) def forward(self, latent): b, _, _, _, _ = latent.shape latent = rearrange(latent, 'b c t h w -> (b t) c h w') latent = self.proj(latent) latent = 
rearrange(latent, '(b t) c h w -> b (t h w) c', b=b) return latent class PositionGetter3D(object): """ return positions of patches """ def __init__(self, ): self.cache_positions = {} def __call__(self, b, t, h, w, device): if not (b,t,h,w) in self.cache_positions: x = torch.arange(w, device=device) y = torch.arange(h, device=device) z = torch.arange(t, device=device) pos = torch.cartesian_prod(z, y, x) # print('PositionGetter3D', PositionGetter3D) pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone() poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) self.cache_positions[b, t, h, w] = (poses, max_poses) pos = self.cache_positions[b, t, h, w] return pos class RoPE3D(torch.nn.Module): def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): super().__init__() self.base = freq self.F0 = F0 self.interpolation_scale_t = interpolation_scale_thw[0] self.interpolation_scale_h = interpolation_scale_thw[1] self.interpolation_scale_w = interpolation_scale_thw[2] self.cache = {} def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): if (D, seq_len, device, dtype) not in self.cache: inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) freqs = torch.cat((freqs, freqs), dim=-1) cos = freqs.cos() # (Seq, Dim) sin = freqs.sin() self.cache[D, seq_len, device, dtype] = (cos, sin) return self.cache[D, seq_len, device, dtype] @staticmethod def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) def apply_rope1d(self, tokens, pos1d, cos, sin): assert pos1d.ndim == 2 # for (ntokens x batch_size x nheads x dim) cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] sin = 
torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] return (tokens * cos) + (self.rotate_half(tokens) * sin) def forward(self, tokens, positions): """ input: * tokens: ntokens x batch_size x nheads x dim * positions: batch_size x ntokens x 3 (t, y and x position of each token) output: * tokens after appplying RoPE3D (ntokens x batch_size x nheads x dim) """ assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" D = tokens.size(3) // 3 poses, max_poses = positions assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3 cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) # split features into three along the feature dimension, and apply rope1d on each half t, y, x = tokens.chunk(3, dim=-1) t = self.apply_rope1d(t, poses[0], cos_t, sin_t) y = self.apply_rope1d(y, poses[1], cos_y, sin_y) x = self.apply_rope1d(x, poses[2], cos_x, sin_x) tokens = torch.cat((t, y, x), dim=-1) return tokens ================================================ FILE: opensora/models/diffusion/opensora_v1_3/__init__.py ================================================ ================================================ FILE: opensora/models/diffusion/opensora_v1_3/modeling_inpaint.py ================================================ import os import numpy as np from torch import nn import torch from einops import rearrange, repeat from typing import Any, Dict, Optional, Tuple from diffusers.configuration_utils import register_to_config from opensora.models.diffusion.common import PatchEmbed2D from opensora.utils.utils import to_2tuple from opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3 as OpenSoraT2V import glob def zero_module(module): for p in 
module.parameters(): nn.init.zeros_(p) return module class OpenSoraInpaint_v1_3(OpenSoraT2V): _supports_gradient_checkpointing = True @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, attention_bias: bool = True, sample_size_h: Optional[int] = None, sample_size_w: Optional[int] = None, sample_size_t: Optional[int] = None, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None, activation_fn: str = "geglu", only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = None, interpolation_scale_h: float = 1.0, interpolation_scale_w: float = 1.0, interpolation_scale_t: float = 1.0, sparse1d: bool = False, sparse_n: int = 2, # inpaint vae_scale_factor_t: int = 4, ): super().__init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, in_channels=in_channels, out_channels=out_channels, num_layers=num_layers, dropout=dropout, cross_attention_dim=cross_attention_dim, attention_bias=attention_bias, sample_size_h=sample_size_h, sample_size_w=sample_size_w, sample_size_t=sample_size_t, patch_size=patch_size, patch_size_t=patch_size_t, activation_fn=activation_fn, only_cross_attention=only_cross_attention, double_self_attention=double_self_attention, upcast_attention=upcast_attention, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, caption_channels=caption_channels, interpolation_scale_h=interpolation_scale_h, interpolation_scale_w=interpolation_scale_w, interpolation_scale_t=interpolation_scale_t, sparse1d=sparse1d, sparse_n=sparse_n, ) self.vae_scale_factor_t = vae_scale_factor_t # init masked_pixel_values and mask conv_in self._init_patched_inputs_for_inpainting() def 
_init_patched_inputs_for_inpainting(self): self.config.sample_size = to_2tuple(self.config.sample_size) self.pos_embed_masked_hidden_states = nn.ModuleList( [ PatchEmbed2D( patch_size=self.config.patch_size, in_channels=self.config.in_channels, embed_dim=self.config.hidden_size, ), zero_module(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)), ] ) self.pos_embed_mask = nn.ModuleList( [ PatchEmbed2D( patch_size=self.config.patch_size, in_channels=self.vae_scale_factor_t, embed_dim=self.config.hidden_size, ), zero_module(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)), ] ) def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame): # inpaint assert hidden_states.shape[1] == 2 * self.config.in_channels + self.vae_scale_factor_t in_channels = self.config.in_channels input_hidden_states, input_masked_hidden_states, input_mask = hidden_states[:, :in_channels], hidden_states[:, in_channels: 2 * in_channels], hidden_states[:, 2 * in_channels:] input_hidden_states = self.pos_embed(input_hidden_states.to(self.dtype)) input_masked_hidden_states = self.pos_embed_masked_hidden_states[0](input_masked_hidden_states.to(self.dtype)) input_masked_hidden_states = self.pos_embed_masked_hidden_states[1](input_masked_hidden_states) input_mask = self.pos_embed_mask[0](input_mask.to(self.dtype)) input_mask = self.pos_embed_mask[1](input_mask) hidden_states = input_hidden_states + input_masked_hidden_states + input_mask added_cond_kwargs = {"resolution": None, "aspect_ratio": None} timestep, embedded_timestep = self.adaln_single( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype ) # b 6d, b d encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d or b, 1, l, d assert encoder_hidden_states.shape[1] == 1 encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d') return hidden_states, encoder_hidden_states, timestep, 
embedded_timestep def OpenSoraInpaint_v1_3_2B_122(**kwargs): return OpenSoraInpaint_v1_3( num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2, caption_channels=4096, cross_attention_dim=2304, activation_fn="gelu-approximate", **kwargs ) OpenSoraInpaint_v1_3_models = { "OpenSoraInpaint_v1_3-2B/122": OpenSoraInpaint_v1_3_2B_122, # 2.7B } OpenSoraInpaint_v1_3_models_class = { "OpenSoraInpaint_v1_3-2B/122": OpenSoraInpaint_v1_3, "OpenSoraInpaint_v1_3": OpenSoraInpaint_v1_3, } ================================================ FILE: opensora/models/diffusion/opensora_v1_3/modeling_opensora.py ================================================ import os import numpy as np from torch import nn import torch from einops import rearrange, repeat from typing import Any, Dict, Optional, Tuple from torch.nn import functional as F from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import is_torch_version, deprecate from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormSingle from diffusers.models.embeddings import PixArtAlphaTextProjection from opensora.models.diffusion.opensora_v1_3.modules import BasicTransformerBlock, Attention from opensora.models.diffusion.common import PatchEmbed2D from opensora.utils.utils import to_2tuple try: import torch_npu from opensora.npu_config import npu_config from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info except: torch_npu = None npu_config = None from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 
1, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, attention_bias: bool = True, sample_size_h: Optional[int] = None, sample_size_w: Optional[int] = None, sample_size_t: Optional[int] = None, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None, activation_fn: str = "geglu", only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, caption_channels: int = None, interpolation_scale_h: float = 1.0, interpolation_scale_w: float = 1.0, interpolation_scale_t: float = 1.0, sparse1d: bool = False, sparse_n: int = 2, ): super().__init__() # Set some common variables used across the board. self.out_channels = in_channels if out_channels is None else out_channels self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim self.gradient_checkpointing = False self._init_patched_inputs() def _init_patched_inputs(self): self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w) interpolation_scale_thw = ( self.config.interpolation_scale_t, self.config.interpolation_scale_h, self.config.interpolation_scale_w ) self.caption_projection = PixArtAlphaTextProjection( in_features=self.config.caption_channels, hidden_size=self.config.hidden_size ) self.pos_embed = PatchEmbed2D( patch_size=self.config.patch_size, in_channels=self.config.in_channels, embed_dim=self.config.hidden_size, ) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.config.hidden_size, self.config.num_attention_heads, self.config.attention_head_dim, dropout=self.config.dropout, cross_attention_dim=self.config.cross_attention_dim, activation_fn=self.config.activation_fn, attention_bias=self.config.attention_bias, only_cross_attention=self.config.only_cross_attention, double_self_attention=self.config.double_self_attention, upcast_attention=self.config.upcast_attention, 
# The functions below are methods of the OpenSoraT2V_v1_3 transformer class
# (its `class` header lies above this chunk).

def _set_gradient_checkpointing(self, module, value=False):
    """Set `module.gradient_checkpointing = value` when the module exposes that flag."""
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value


def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: Optional[torch.LongTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    **kwargs,
):
    """Run the diffusion transformer on a latent video.

    Args:
        hidden_states: latent of shape (b, c, t, h, w).
        timestep: diffusion timesteps, forwarded to `adaln_single`.
        encoder_hidden_states: text features; `_operate_on_patched_inputs`
            asserts that dim 1 has size 1.
        attention_mask: a 4-D (b, t, h, w) keep-mask (1 = keep, 0 = discard) is
            pooled to patch resolution and turned into an additive bias
            (keep = 0, discard = -10000); other ranks are passed through unchanged.
        encoder_attention_mask: a 3-D (b, 1, l) keep-mask is turned into the same
            kind of additive bias; other ranks are passed through unchanged.
        return_dict: return `Transformer2DModelOutput` when True, else a 1-tuple.

    Raises:
        ValueError: if either mask is None (the sparse-mask preparation needs both).
    """
    if attention_mask is None or encoder_attention_mask is None:
        raise ValueError(
            "attention_mask and encoder_attention_mask are both required to build the sparse attention masks"
        )

    batch_size, _, frame, _, _ = hidden_states.shape
    patch_t = self.config.patch_size_t
    patch = self.config.patch_size

    if attention_mask.ndim == 4:
        # b t h w -> b 1 t h w, pooled so a patch is kept if any of its pixels is kept
        attention_mask = attention_mask.to(self.dtype).unsqueeze(1)
        attention_mask = F.max_pool3d(
            attention_mask,
            kernel_size=(patch_t, patch, patch),
            stride=(patch_t, patch, patch),
        )
        attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)')
        attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask.ndim == 3:  # b, 1, l
        encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0

    # 1. Input: token-grid size after patchifying
    frame = ((frame - 1) // patch_t + 1) if frame % 2 == 1 else frame // patch_t
    height, width = hidden_states.shape[-2] // patch, hidden_states.shape[-1] // patch

    hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
        hidden_states, encoder_hidden_states, timestep, batch_size, frame
    )
    # Blocks work in sequence-first layout:
    # x (t*h*w b d) or (t//sp*h*w b d); cond (l b d) or (l//sp b d)
    hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous()
    encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous()
    timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous()

    if npu_config is None:
        if get_sequence_parallel_state():
            head_num = self.config.num_attention_heads // nccl_info.world_size
        else:
            head_num = self.config.num_attention_heads
    else:
        head_num = None

    # Build one mask set per sparsity factor that is actually in use. Key 1 serves every
    # block whose processor does not sparsify. The key comes from the processor itself,
    # not from a fixed [1, 4] list and layer-index window, so any `sparse_n` works and a
    # non-sparse block never gets a sparse-shaped mask.
    processors = [block.attn1.processor for block in self.transformer_blocks]
    mask_keys = {1} | {proc.sparse_n for proc in processors if proc.sparse1d}
    sparse_mask = {
        key: Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, key, head_num)
        for key in sorted(mask_keys)
    }

    use_checkpointing = self.training and self.gradient_checkpointing
    if use_checkpointing:
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                return module(*inputs)

            return custom_forward

    # 2. Blocks
    for block, proc in zip(self.transformer_blocks, processors):
        mask_key = proc.sparse_n if proc.sparse1d else 1
        block_mask, block_encoder_mask = sparse_mask[mask_key][proc.sparse_group]

        if use_checkpointing:
            # positional order must match the block's forward signature
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                block_mask,
                encoder_hidden_states,
                block_encoder_mask,
                timestep,
                frame,
                height,
                width,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states,
                attention_mask=block_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=block_encoder_mask,
                timestep=timestep,
                frame=frame,
                height=height,
                width=width,
            )

    # To (b, t*h*w, h) or (b, t//sp*h*w, h)
    hidden_states = rearrange(hidden_states, 's b h -> b s h', b=batch_size).contiguous()

    # 3. Output
    output = self._get_output_for_patched_inputs(
        hidden_states=hidden_states,
        timestep=timestep,
        embedded_timestep=embedded_timestep,
        num_frames=frame,
        height=height,
        width=width,
    )  # b c t h w

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)


def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame):
    """Embed the latent, the timestep and the caption features.

    Returns `(hidden_states, encoder_hidden_states, timestep, embedded_timestep)`.
    `frame` is accepted for signature compatibility and is not used here.
    """
    hidden_states = self.pos_embed(hidden_states.to(self.dtype))

    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    timestep, embedded_timestep = self.adaln_single(
        timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
    )  # b 6d, b d

    encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # b, 1, l, d
    assert encoder_hidden_states.shape[1] == 1
    encoder_hidden_states = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d')

    return hidden_states, encoder_hidden_states, timestep, embedded_timestep


def _get_output_for_patched_inputs(
    self, hidden_states, timestep, embedded_timestep, num_frames, height, width
):
    """Apply the final modulated norm and projection, then un-patchify to (b, c, t, h, w).

    `timestep` is accepted for signature compatibility and is not used here.
    """
    patch_t = self.config.patch_size_t
    patch = self.config.patch_size

    shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale) + shift
    hidden_states = self.proj_out(hidden_states)
    hidden_states = hidden_states.squeeze(1)

    # unpatchify: (n, t, h, w, pt, p, p, c) -> (n, c, t*pt, h*p, w*p)
    hidden_states = hidden_states.reshape(
        shape=(-1, num_frames, height, width, patch_t, patch, patch, self.out_channels)
    )
    hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(-1, self.out_channels, num_frames * patch_t, height * patch, width * patch)
    )
    return output
def OpenSoraT2V_v1_3_2B_122(**kwargs):
    """Build the "OpenSoraT2V_v1_3-2B/122" preset of `OpenSoraT2V_v1_3`.

    The preset fixes depth, head layout, patch sizes, caption width and activation.
    Any of those may also appear in `kwargs`, in which case the caller's value wins.
    Previously a repeated key (e.g. `activation_fn`, which the smoke test below passes)
    raised `TypeError: got multiple values for keyword argument`.
    `skip_connection` is accepted and ignored.
    """
    kwargs.pop('skip_connection', None)
    config = dict(
        num_layers=32,
        attention_head_dim=96,
        num_attention_heads=24,
        patch_size_t=1,
        patch_size=2,
        caption_channels=4096,
        cross_attention_dim=2304,
        activation_fn="gelu-approximate",
    )
    config.update(kwargs)
    return OpenSoraT2V_v1_3(**config)


# Registry: model name -> factory.
OpenSora_v1_3_models = {
    "OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3_2B_122,  # 2.7B
}

# Registry: model name -> model class.
OpenSora_v1_3_models_class = {
    "OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3,
    "OpenSoraT2V_v1_3": OpenSoraT2V_v1_3,
}

if __name__ == '__main__':
    # Smoke test: build the 2B preset, optionally load a checkpoint, run one forward pass.
    from opensora.models.causalvideovae import ae_stride_config, ae_channel_config
    from opensora.models.causalvideovae import ae_norm, ae_denorm
    from opensora.models import CausalVAEModelWrapper

    args = type('args', (),
    {
        'ae': 'WFVAEModel_D8_4x8x8',
        'model_max_length': 300,
        'max_height': 176,
        'max_width': 176,
        'num_frames': 33,
        'compress_kv_factor': 1,
        'interpolation_scale_t': 1,
        'interpolation_scale_h': 1,
        'interpolation_scale_w': 1,
        "sparse1d": True,
        "sparse_n": 4,
        "rank": 64,
    }
    )
    b = 2
    c = 8
    cond_c = 4096
    num_timesteps = 1000
    ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae]
    latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w)
    num_frames = (args.num_frames - 1) // ae_stride_t + 1

    device = torch.device('cuda:0')
    model = OpenSoraT2V_v1_3_2B_122(
        in_channels=c,
        out_channels=c,
        # the constructor annotates these as ints, so pass the two sides separately
        sample_size_h=latent_size[0],
        sample_size_w=latent_size[1],
        sample_size_t=num_frames,
        activation_fn="gelu-approximate",
        attention_bias=True,
        double_self_attention=False,
        norm_elementwise_affine=False,
        norm_eps=1e-06,
        only_cross_attention=False,
        upcast_attention=False,
        interpolation_scale_t=args.interpolation_scale_t,
        interpolation_scale_h=args.interpolation_scale_h,
        interpolation_scale_w=args.interpolation_scale_w,
        sparse1d=args.sparse1d,
        sparse_n=args.sparse_n,
    )

    # Best effort: the checkpoint path is machine-specific, so a failed load is only reported.
    try:
        path = "/storage/ongoing/new/7.19anyres/Open-Sora-Plan/bs32x8x1_anyx93x640x640_fps16_lr1e-5_snr5_ema9999_sparse1d4_dit_l_mt5xxl_vpred_zerosnr/checkpoint-43000/model_ema/diffusion_pytorch_model.safetensors"
        from safetensors.torch import load_file as safe_load
        ckpt = safe_load(path, device="cpu")
        msg = model.load_state_dict(ckpt, strict=True)
        print(msg)
    except Exception as e:
        print(e)
    print(model)
    print(f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B')
    model = model.to(device)

    x = torch.randn(b, c, num_frames, latent_size[0], latent_size[1]).to(device)
    cond = torch.randn(b, 1, args.model_max_length, cond_c).to(device)
    attn_mask = torch.randint(0, 2, (b, num_frames, latent_size[0], latent_size[1])).to(device)  # b t h w
    cond_mask = torch.randint(0, 2, (b, 1, args.model_max_length)).to(device)  # b 1 l
    timestep = torch.randint(0, num_timesteps, (b,), device=device)
    model_kwargs = dict(
        hidden_states=x,
        encoder_hidden_states=cond,
        attention_mask=attn_mask,
        encoder_attention_mask=cond_mask,
        timestep=timestep,
    )
    with torch.no_grad():
        output = model(**model_kwargs)
    print(output[0].shape)
class Attention(Attention_):
    """diffusers `Attention` that always runs an `OpenSoraAttnProcessor2_0`, plus mask helpers."""

    def __init__(
        self, interpolation_scale_thw, sparse1d, sparse_n, sparse_group, is_cross_attn, **attn_kwargs
    ):
        # The processor is built first and handed to the base class.
        # Everything else goes to diffusers' Attention untouched.
        super().__init__(
            processor=OpenSoraAttnProcessor2_0(
                interpolation_scale_thw=interpolation_scale_thw,
                sparse1d=sparse1d,
                sparse_n=sparse_n,
                sparse_group=sparse_group,
                is_cross_attn=is_cross_attn,
            ),
            **attn_kwargs,
        )

    @staticmethod
    def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num):
        """Reshape additive masks to match the two sparse token layouts.

        Returns a dict keyed by the processor's `sparse_group` flag:
        `{False: (self_mask_1d, cross_mask), True: (self_mask_grouped, cross_mask)}`.
        """
        attention_mask = attention_mask.unsqueeze(1)
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Pad the key axis up to a multiple of sparse_n ** 2; padded keys get a large negative bias.
        block = sparse_n * sparse_n
        pad_len = (-attention_mask.shape[-1]) % block
        padded_mask = F.pad(attention_mask, (0, pad_len, 0, 0), value=-9980.0)

        mask_1d = rearrange(
            padded_mask,
            'b 1 1 (g k) -> (k b) 1 1 g',
            k=sparse_n
        )
        mask_group = rearrange(
            padded_mask,
            'b 1 1 (n m k) -> (m b) 1 1 (n k)',
            m=sparse_n,
            k=sparse_n
        )
        encoder_mask = encoder_attention_mask.repeat(sparse_n, 1, 1, 1)

        if npu_config is None:
            # One copy of each mask per attention head.
            mask_1d = mask_1d.repeat_interleave(head_num, dim=1)
            mask_group = mask_group.repeat_interleave(head_num, dim=1)
            encoder_mask_1d = encoder_mask.repeat_interleave(head_num, dim=1)
        else:
            mask_1d = npu_config.get_attention_mask(mask_1d, mask_1d.shape[-1])
            mask_group = npu_config.get_attention_mask(mask_group, mask_group.shape[-1])
            encoder_mask_1d = npu_config.get_attention_mask(encoder_mask, mask_1d.shape[-1])

        # Both layouts share the same cross-attention mask.
        return {
            False: (mask_1d, encoder_mask_1d),
            True: (mask_group, encoder_mask_1d),
        }

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        r"""
        Expand an attention mask across attention heads.

        Args:
            attention_mask (`torch.Tensor`): mask to expand; `None` is returned unchanged.
            target_length (`int`): expected size of the last mask axis. On mismatch the mask is
                right-padded with `target_length` zeros and a diagnostic line is printed.
            batch_size (`int`): batch size, used to decide whether a 3-D mask still needs expanding.
            out_dim (`int`, *optional*, defaults to `3`): `3` repeats along dim 0, `4` inserts a
                head axis at dim 1 and repeats along it.

        Returns:
            `torch.Tensor`: the expanded mask.
        """
        if get_sequence_parallel_state():
            head_size = self.heads // xccl_info.world_size  # e.g, 24 // 8
        else:
            head_size = self.heads

        if attention_mask is None:  # b 1 t*h*w in sa, b 1 l in ca
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            print(f'attention_mask.shape, {attention_mask.shape}, current_length, {current_length}, target_length, {target_length}')
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 4:
            return attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)
        if out_dim == 3 and attention_mask.shape[0] < batch_size * head_size:
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask
class OpenSoraAttnProcessor2_0:
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (or `npu_config.run_attention`
    when `npu_config` is set).

    Works on sequence-first tensors (tokens, batch, dim). Applies 3D RoPE to queries and keys in
    self-attention, and can fold the token sequence into `sparse_n` shorter sequences along the
    batch axis before attending (`sparse1d`).
    """

    def __init__(self, interpolation_scale_thw=(1, 1, 1), sparse1d=False, sparse_n=2, sparse_group=False, is_cross_attn=True):
        # sparse1d:     fold tokens into the batch axis before attention
        # sparse_n:     folding factor
        # sparse_group: use the grouped folding pattern instead of the strided one
        # is_cross_attn: skip RoPE and replicate (rather than fold) keys/values
        self.sparse1d = sparse1d
        self.sparse_n = sparse_n
        self.sparse_group = sparse_group
        self.is_cross_attn = is_cross_attn
        self.interpolation_scale_thw = interpolation_scale_thw

        self._init_rope(interpolation_scale_thw)

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def _init_rope(self, interpolation_scale_thw):
        """Create the rotary embedding and the position-grid helper used by self-attention."""
        self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
        self.position_getter = PositionGetter3D()

    def _sparse_1d(self, x, frame, height, width):
        """
        Fold the token axis of `x` (ntokens x batch_size x dim) into the batch axis.

        The sequence is zero-padded at the end to a multiple of `sparse_n ** 2`, then rearranged
        so the token axis shrinks by `sparse_n` and the batch axis grows by `sparse_n`.
        Returns `(x, pad_len)`; `pad_len` is needed by `_reverse_sparse_1d`.
        """
        l = x.shape[0]
        assert l == frame*height*width
        pad_len = 0
        if l % (self.sparse_n * self.sparse_n) != 0:
            pad_len = self.sparse_n * self.sparse_n - l % (self.sparse_n * self.sparse_n)
        if pad_len != 0:
            # F.pad lists the last dim first, so the final pair pads dim 0 (tokens)
            x = F.pad(x, (0, 0, 0, 0, 0, pad_len))
        if not self.sparse_group:
            # strided: token s goes to sub-sequence s % sparse_n
            x = rearrange(x, '(g k) b d -> g (k b) d', k=self.sparse_n)
        else:
            # grouped: runs of sparse_n consecutive tokens, distributed round-robin over sub-sequences
            x = rearrange(x, '(n m k) b d -> (n k) (m b) d', m=self.sparse_n, k=self.sparse_n)
        return x, pad_len

    def _reverse_sparse_1d(self, x, frame, height, width, pad_len):
        """
        Undo `_sparse_1d` on a (ntokens x batch_size x dim) tensor and drop the padding tokens.
        """
        assert x.shape[0] == (frame*height*width+pad_len) // self.sparse_n
        if not self.sparse_group:
            x = rearrange(x, 'g (k b) d -> (g k) b d', k=self.sparse_n)
        else:
            x = rearrange(x, '(n k) (m b) d -> (n m k) b d', m=self.sparse_n, k=self.sparse_n)
        x = x[:frame*height*width, :, :]
        return x

    def _sparse_1d_kv(self, x):
        """
        Replicate (ntokens x batch_size x dim) keys/values `sparse_n` times along the batch axis,
        so each folded query sub-sequence sees the full conditioning sequence.
        """
        x = repeat(x, 's b d -> s (k b) d', k=self.sparse_n)
        return x

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        frame: int = 8,
        height: int = 16,
        width: int = 16,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Compute attention for `attn` on sequence-first inputs.

        Args:
            attn: owning attention module; supplies `to_q/to_k/to_v/to_out`, `heads` and
                `rescale_output_factor`.
            hidden_states: query source, (tokens, batch, dim).
            encoder_hidden_states: key/value source; `None` means self-attention on `hidden_states`.
            attention_mask: passed unchanged to the attention kernel.
                NOTE(review): presumably already expanded by `Attention.prepare_sparse_mask` — confirm.
            temb: unused.
            frame, height, width: token-grid size on this rank; with sequence parallelism `frame`
                is multiplied by the world size before use.

        Returns:
            Tensor in the same (tokens, batch, dim) layout as the input queries.
        """
        # Kept from the upstream processor; not used below (residual add is disabled).
        residual = hidden_states

        sequence_length, batch_size, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        FA_head_num = attn.heads
        total_frame = frame

        if get_sequence_parallel_state():
            sp_size = xccl_info.world_size
            FA_head_num = attn.heads // sp_size
            total_frame = frame * sp_size
            # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d]
            query = all_to_all_SBH(query.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)
            key = all_to_all_SBH(key.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)
            value = all_to_all_SBH(value.view(-1, attn.heads, head_dim), scatter_dim=1, gather_dim=0)

        query = query.view(-1, batch_size, FA_head_num, head_dim)
        key = key.view(-1, batch_size, FA_head_num, head_dim)

        if not self.is_cross_attn:
            # require the shape of (ntokens x batch_size x nheads x dim)
            pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width, device=query.device)
            query = self.rope(query, pos_thw)
            key = self.rope(key, pos_thw)

        # back to (tokens, batch, heads * head_dim) for the sparse folding below
        query = query.view(-1, batch_size, FA_head_num * head_dim)
        key = key.view(-1, batch_size, FA_head_num * head_dim)
        value = value.view(-1, batch_size, FA_head_num * head_dim)

        if self.sparse1d:
            query, pad_len = self._sparse_1d(query, total_frame, height, width)
            if self.is_cross_attn:
                # conditioning tokens are not folded, only replicated per sub-sequence
                key = self._sparse_1d_kv(key)
                value = self._sparse_1d_kv(value)
            else:
                key, pad_len = self._sparse_1d(key, total_frame, height, width)
                value, pad_len = self._sparse_1d(value, total_frame, height, width)

        if npu_config is not None:
            hidden_states = npu_config.run_attention(query, key, value, attention_mask, "SBH", head_dim, FA_head_num)
        else:
            query = rearrange(query, 's b (h d) -> b h s d', h=FA_head_num)
            key = rearrange(key, 's b (h d) -> b h s d', h=FA_head_num)
            value = rearrange(value, 's b (h d) -> b h s d', h=FA_head_num)
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # only the memory-efficient kernel is enabled here
            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
                )
            hidden_states = rearrange(hidden_states, 'b h s d -> s b (h d)', h=FA_head_num)

        if self.sparse1d:
            hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len)

        # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
        if get_sequence_parallel_state():
            hidden_states = all_to_all_SBH(hidden_states.reshape(-1, FA_head_num, head_dim), scatter_dim=0, gather_dim=1)
            hidden_states = hidden_states.view(-1, batch_size, inner_dim)

        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # (the upstream `attn.residual_connection` add is intentionally not applied here)
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated feed-forward.

    Shift/scale/gate values come from a learned 6-row table added to the timestep embedding.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        interpolation_scale_thw: Tuple[int] = (1, 1, 1),
        sparse1d: bool = False,
        sparse_n: int = 2,
        sparse_group: bool = False,
    ):
        super().__init__()

        # Settings common to both attention layers.
        shared_attn_kwargs = dict(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
            interpolation_scale_thw=interpolation_scale_thw,
            sparse1d=sparse1d,
            sparse_n=sparse_n,
            sparse_group=sparse_group,
        )

        # Self-attention (with its own pre-norm).
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
        self.attn1 = Attention(
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            is_cross_attn=False,
            **shared_attn_kwargs,
        )

        # Cross-attention; norm2 is applied before the feed-forward in `forward`.
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn2 = Attention(
            cross_attention_dim=None if double_self_attention else cross_attention_dim,
            is_cross_attn=True,
            **shared_attn_kwargs,
        )  # is self-attn if encoder_hidden_states is none

        # Feed-forward.
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # Learned offsets for the six modulation vectors.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        frame: int = None,
        height: int = None,
        width: int = None,
    ) -> torch.FloatTensor:
        """Apply the block to sequence-first `hidden_states` (batch is dim 1) and return the result."""
        batch_size = hidden_states.shape[1]
        grid = dict(frame=frame, height=height, width=width)

        modulation = self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=0)

        # Self-attention on the modulated, normalised input; gated residual.
        self_attn_in = self.norm1(hidden_states)
        self_attn_in = self_attn_in * (1 + scale_msa) + shift_msa
        self_attn_out = self.attn1(
            self_attn_in,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
            **grid,
        )
        hidden_states = gate_msa * self_attn_out + hidden_states

        # Cross-attention directly on the residual stream (no norm, no gate).
        cross_attn_out = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **grid,
        )
        hidden_states = cross_attn_out + hidden_states

        # Feed-forward on the modulated, normalised stream; gated residual.
        ff_in = self.norm2(hidden_states)
        ff_in = ff_in * (1 + scale_mlp) + shift_mlp
        hidden_states = gate_mlp * self.ff(ff_in) + hidden_states

        return hidden_states
# Build configuration for the AMT-G frame-interpolation network.
AMT_G = {
    'name': 'networks.AMT-G.Model',
    'params': {
        'corr_radius': 3,
        'corr_lvls': 4,
        'num_flows': 5,
    }
}


def init(device="cuda"):
    '''
    Initialize the anchor resolution and memory budget used to pick a down-scaling factor.

    params:
        device: str or torch.device. Any CUDA device ("cuda", "cuda:0",
            torch.device("cuda")) selects the VRAM-aware settings; anything else
            disables resizing.
    returns:
        (anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail)
    '''
    # Normalise first: `torch.device("cuda") == "cuda"` is False and so is `"cuda:0" == "cuda"`,
    # so comparing the raw argument to the string silently picked the CPU settings on GPU.
    device = torch.device(device)
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
        print("VRAM available: {:.1f} MB".format(vram_avail / 1024 ** 2))
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192*8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail
def get_input_video_from_path(input_path, device="cuda"):
    '''
    Get the input video from the input_path.

    params:
        input_path: str, the path of the input video.
        device: str or torch.device, the device to run the model.
    returns:
        inputs: list, the list of the (padded) input frame tensors.
        scale: float, the scale of the input frames.
        padder: InputPadder, the padder to pad the input frames.
    raises:
        TypeError: if the file extension is not a known video extension.
        IOError: if the video reports a non-positive frame size (e.g. it cannot be opened).
    '''
    video_exts = ('.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm')
    # Case-insensitive match covers the upper-case variants the original listed explicitly.
    if osp.splitext(input_path)[-1].lower() not in video_exts:
        raise TypeError("Input should be a video.")

    anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device)

    vcap = cv2.VideoCapture(input_path)
    try:
        w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if h <= 0 or w <= 0:
            # would otherwise surface as a ZeroDivisionError in the scale computation
            raise IOError(f"Cannot read a video stream from {input_path}")

        # Shrink the frames when the estimated memory need exceeds the available VRAM.
        scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
        padding = int(16 / scale)
        padder = InputPadder((h, w), padding)

        inputs = []
        while True:
            ret, frame = vcap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_t = img2tensor(frame).to(device)
            frame_t = padder.pad(frame_t)
            inputs.append(frame_t)
    finally:
        # always free the decoder handle, even if a frame fails to convert
        vcap.release()

    print(f'Loading the [video] from {input_path}, the number of frames [{len(inputs)}]')
    return inputs, scale, padder


def load_model(ckpt_path, device="cuda"):
    '''
    Load the frame interpolation model.

    params:
        ckpt_path: str, checkpoint file containing a 'state_dict' entry.
        device: str or torch.device, where the model is placed.
    returns:
        model: the network in eval mode on `device`.
    '''
    network_cfg = AMT_G
    network_name = network_cfg['name']
    print(f'Loading [{network_name}] from [{ckpt_path}]...')
    model = build_from_cfg(network_cfg)
    # Load onto CPU first so a checkpoint saved from GPU also loads on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt['state_dict'])
    model = model.to(device)
    model.eval()
    return model
def interpolater(model, inputs, scale, padder, iters=1, device=None):
    '''
    Interpolate frames with the interpolation model.

    params:
        model: nn.Module, the frame interpolation model.
        inputs: list, the list of the input frames.
        scale: float, the scale of the input frames.
        padder: object whose `unpad(*frames)` removes the padding added at load time.
        iters: int, the number of iterations of interpolation. Each iteration inserts one
            frame between every neighbouring pair, so m frames become 2 ** iters * (m - 1) + 1.
        device: device to run the model on. Defaults to the device of the model's
            parameters (or of the first input frame if the model has none). The function
            previously read a module-level `device` that only exists when the file runs
            as a script.
    returns:
        outputs: the un-padded output frames as returned by `padder.unpad`
            (an empty list when `inputs` is empty).
    '''
    print(f'Start frame interpolation:')
    if len(inputs) == 0:
        return []

    if device is None:
        first_param = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
        device = first_param.device if first_param is not None else inputs[0].device

    # Always ask for the temporal midpoint.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    outputs = list(inputs)  # keeps `outputs` defined when iters <= 0
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            # results are moved to CPU after each pair
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs

    outputs = padder.unpad(*outputs)
    return outputs


def write(outputs, input_path, output_path, frame_rate=30):
    '''
    Write the frames as `<output_path>/output_<input file name>.mp4`.

    params:
        outputs: non-empty list of frame tensors; all frames use the size of the first one.
        input_path: str, only its base name is used to name the result.
        output_path: str, target directory (created if missing).
        frame_rate: frames per second of the written video.
    '''
    os.makedirs(output_path, exist_ok=True)

    # (width, height) as plain ints, which is what cv2.VideoWriter expects
    size = tuple(int(v) for v in outputs[0].shape[2:][::-1])

    _, file_name_with_extension = os.path.split(input_path)
    file_name, _ = os.path.splitext(file_name_with_extension)

    save_video_path = f'{output_path}/output_{file_name}.mp4'
    writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*"mp4v"),
                             frame_rate, size)
    try:
        for imgt_pred in outputs:
            imgt_pred = tensor2img(imgt_pred)
            imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
            writer.write(imgt_pred)
    finally:
        # flush and close the file even if a frame fails to convert
        writer.release()
    print(f"Demo video is saved to [{save_video_path}]")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, default='amt-g.pth', help="The pretrained model.")
    parser.add_argument('--niters', type=int, default=1, help="Iter of Interpolation. The number of frames will be double after per iter.")
    parser.add_argument('--input', default="test.mp4", help="Input video.")
    parser.add_argument('--output_path', type=str, default='results', help="Output path.")
    parser.add_argument('--frame_rate', type=int, default=30, help="Frames rate of the output video.")

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt_path = args.ckpt
    input_path = args.input
    output_path = args.output_path
    iters = int(args.niters)
    frame_rate = int(args.frame_rate)

    inputs, scale, padder = get_input_video_from_path(input_path, device)
    model = load_model(ckpt_path, device)
    outputs = interpolater(model, inputs, scale, padder, iters, device=device)
    write(outputs, input_path, output_path, frame_rate)
self.encoder = Encoder(channels, large=True) self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) self.update4 = self._get_updateblock(112, None) self.update3_low = self._get_updateblock(96, 2.0) self.update2_low = self._get_updateblock(84, 4.0) self.update3_high = self._get_updateblock(96, None) self.update2_high = self._get_updateblock(84, None) self.comb_block = nn.Sequential( nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), nn.PReLU(6*self.num_flows), nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), ) def _get_updateblock(self, cdim, scale_factor=None): return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, corr_dim=256, corr_dim2=192, fc_dim=188, scale_factor=scale_factor, corr_levels=self.corr_levels, radius=self.radius) def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # based on linear assumption t1_scale = 1. / embt t0_scale = 1. / (1. 
- embt) if downsample != 1: inv = 1 / downsample flow0 = inv * resize(flow0, scale_factor=inv) flow1 = inv * resize(flow1, scale_factor=inv) corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) corr = torch.cat([corr0, corr1], dim=1) flow = torch.cat([flow0, flow1], dim=1) return corr, flow def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) img0 = img0 - mean_ img1 = img1 - mean_ img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 b, _, h, w = img0_.shape coord = coords_grid(b, h // 8, w // 8, img0.device) fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) ######################################### the 4th decoder ######################################### up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1) # residue update with lookup corr delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) up_flow0_4 = up_flow0_4 + delta_flow0_4 up_flow1_4 = up_flow1_4 + delta_flow1_4 ft_3_ = ft_3_ + delta_ft_3_ ######################################### the 3rd decoder ######################################### up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) corr_3, flow_3 = self._corr_scale_lookup(corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2) # residue update with lookup corr delta_ft_2_, delta_flow_3 = 
self.update3_low(ft_2_, flow_3, corr_3) delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) up_flow0_3 = up_flow0_3 + delta_flow0_3 up_flow1_3 = up_flow1_3 + delta_flow1_3 ft_2_ = ft_2_ + delta_ft_2_ # residue update with lookup corr (hr) corr_3 = resize(corr_3, scale_factor=2.0) up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) ft_2_ += delta_ft_2_ up_flow0_3 += delta_up_flow_3[:, 0:2] up_flow1_3 += delta_up_flow_3[:, 2:4] ######################################### the 2nd decoder ######################################### up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) corr_2, flow_2 = self._corr_scale_lookup(corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4) # residue update with lookup corr delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2) delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) up_flow0_2 = up_flow0_2 + delta_flow0_2 up_flow1_2 = up_flow1_2 + delta_flow1_2 ft_1_ = ft_1_ + delta_ft_1_ # residue update with lookup corr (hr) corr_2 = resize(corr_2, scale_factor=4.0) up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) ft_1_ += delta_ft_1_ up_flow0_2 += delta_up_flow_2[:, 0:2] up_flow1_2 += delta_up_flow_2[:, 2:4] ######################################### the 1st decoder ######################################### up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) if scale_factor != 1.0: up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) mask = resize(mask, scale_factor=(1.0/scale_factor)) img_res = resize(img_res, scale_factor=(1.0/scale_factor)) # Merge multiple predictions imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, 
def _make_norm(norm_fn, num_channels, num_groups):
    """Build the normalisation layer selected by ``norm_fn``.

    Args:
        norm_fn: One of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        num_channels: Number of channels the layer normalises.
        num_groups: Group count, only used for ``'group'``.

    Returns:
        The normalisation module; ``'none'`` yields an empty ``nn.Sequential``
        that acts as the identity.

    Raises:
        ValueError: If ``norm_fn`` is not one of the supported names.
    """
    if norm_fn == 'group':
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
    if norm_fn == 'batch':
        return nn.BatchNorm2d(num_channels)
    if norm_fn == 'instance':
        return nn.InstanceNorm2d(num_channels)
    if norm_fn == 'none':
        return nn.Sequential()
    # Fail at construction time instead of leaving norm attributes undefined.
    raise ValueError(
        f"unsupported norm_fn {norm_fn!r}; "
        "expected 'group', 'batch', 'instance' or 'none'")


def _init_weights(module):
    """Kaiming-initialise every conv and reset every norm layer inside ``module``."""
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
            # InstanceNorm2d is non-affine by default, so weight/bias may be None.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def _merge_inputs(x):
    """Concatenate a list/tuple of tensors along the batch axis.

    Returns:
        ``(tensor, sizes)`` where ``sizes`` holds the batch size of every input,
        or ``(x, None)`` when ``x`` is already a single tensor.
    """
    if isinstance(x, (tuple, list)):
        sizes = [t.shape[0] for t in x]
        return torch.cat(x, dim=0), sizes
    return x, None


def _split_outputs(x, sizes):
    """Undo ``_merge_inputs``: split ``x`` back into per-input chunks."""
    if sizes is None:
        return x
    return torch.split(x, sizes, dim=0)


class BottleneckBlock(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand.

    Args:
        in_planes: Input channels.
        planes: Output channels; the inner convolutions use ``planes // 4``.
        norm_fn: ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        stride: Stride of the 3x3 convolution. When it is not 1 the skip path
            is projected by a strided 1x1 convolution followed by a norm.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        self.norm1 = _make_norm(norm_fn, planes//4, num_groups)
        self.norm2 = _make_norm(norm_fn, planes//4, num_groups)
        self.norm3 = _make_norm(norm_fn, planes, num_groups)
        if stride != 1:
            # Kept as an attribute as well as inside `downsample` so that the
            # parameter names stay the same as before.
            self.norm4 = _make_norm(norm_fn, planes, num_groups)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection.

    Args:
        in_planes: Input channels.
        planes: Output channels.
        norm_fn: ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        stride: Stride of the first convolution. When it is not 1 the skip path
            is projected by a strided 1x1 convolution followed by a norm.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        self.norm1 = _make_norm(norm_fn, planes, num_groups)
        self.norm2 = _make_norm(norm_fn, planes, num_groups)
        if stride != 1:
            self.norm3 = _make_norm(norm_fn, planes, num_groups)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class SmallEncoder(nn.Module):
    """Bottleneck-based feature encoder with an overall stride of 8.

    Args:
        output_dim: Channels of the returned feature map.
        norm_fn: ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        dropout: ``Dropout2d`` probability applied to the output while
            training; 0 disables it.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.norm1 = _make_norm(norm_fn, 32, 8)

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        _init_weights(self)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of tensors processed as one batch.

        For list/tuple input the result is a tuple with one feature map per
        input, split according to each input's own batch size.
        """
        x, sizes = _merge_inputs(x)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        return _split_outputs(x, sizes)


class BasicEncoder(nn.Module):
    """Residual feature encoder (64 -> 72 -> 128 channels) with stride 8.

    Args:
        output_dim: Channels of the returned feature map.
        norm_fn: ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        dropout: ``Dropout2d`` probability applied to the output while
            training; 0 disables it.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.norm1 = _make_norm(norm_fn, 64, 8)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(72, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        _init_weights(self)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of tensors processed as one batch.

        For list/tuple input the result is a tuple with one feature map per
        input, split according to each input's own batch size.
        """
        x, sizes = _merge_inputs(x)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        return _split_outputs(x, sizes)


class LargeEncoder(nn.Module):
    """Residual feature encoder (64 -> 112 -> 160 -> 160 channels) with stride 8.

    Args:
        output_dim: Channels of the returned feature map.
        norm_fn: ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        dropout: ``Dropout2d`` probability applied to the output while
            training; 0 disables it.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(LargeEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.norm1 = _make_norm(norm_fn, 64, 8)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(112, stride=2)
        self.layer3 = self._make_layer(160, stride=2)
        self.layer3_2 = self._make_layer(160, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        _init_weights(self)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of tensors processed as one batch.

        For list/tuple input the result is a tuple with one feature map per
        input, split according to each input's own batch size.
        """
        x, sizes = _merge_inputs(x)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer3_2(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        return _split_outputs(x, sizes)
def resize(x, scale_factor):
    """Bilinearly rescale ``x`` by ``scale_factor`` (``align_corners=False``)."""
    return F.interpolate(
        x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1,
             dilation=1, groups=1, bias=True):
    """Return ``nn.Sequential(Conv2d, PReLU)`` with per-channel PReLU slopes."""
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                     dilation, groups, bias=bias)
    activation = nn.PReLU(out_channels)
    return nn.Sequential(conv, activation)


class ResBlock(nn.Module):
    """Residual block that gives the last ``side_channels`` channels extra convs.

    After ``conv1`` and ``conv3`` the trailing ``side_channels`` channels are
    passed through ``conv2`` / ``conv4`` and concatenated back before the next
    full-width convolution. The block output is ``PReLU(x + conv5(...))``.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        # convrelu's defaults (3x3, stride 1, padding 1) give the same
        # Conv2d + PReLU pair, hence the same parameter names.
        self.conv1 = convrelu(in_channels, in_channels, bias=bias)
        self.conv2 = convrelu(side_channels, side_channels, bias=bias)
        self.conv3 = convrelu(in_channels, in_channels, bias=bias)
        self.conv4 = convrelu(side_channels, side_channels, bias=bias)
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def _refine_side(self, feat, side_conv):
        """Run ``side_conv`` on the trailing side channels and re-attach them."""
        untouched = feat[:, :-self.side_channels]
        refined = side_conv(feat[:, -self.side_channels:])
        return torch.cat([untouched, refined], 1)

    def forward(self, x):
        hidden = self._refine_side(self.conv1(x), self.conv2)
        hidden = self._refine_side(self.conv3(hidden), self.conv4)
        return self.prelu(x + self.conv5(hidden))


class Encoder(nn.Module):
    """Feature pyramid; each level halves the resolution.

    Args:
        channels: Output channels of every pyramid level, in order.
        large: If True the first level uses a 7x7 kernel instead of 3x3.
    """

    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
        self.channels = channels
        in_ch = 3
        for level, out_ch in enumerate(channels, 1):
            if large and level == 1:
                kernel, pad = 7, 3
            else:
                kernel, pad = 3, 1
            stage = nn.Sequential(
                convrelu(in_ch, out_ch, kernel, 2, pad),
                convrelu(out_ch, out_ch, 3, 1, 1),
            )
            self.add_module(f'pyramid{level}', stage)
            in_ch = out_ch

    def forward(self, in_x):
        """Return the list of feature maps, one per level, finest first."""
        feats = []
        feat = in_x
        for level in range(1, len(self.channels) + 1):
            feat = getattr(self, f'pyramid{level}')(feat)
            feats.append(feat)
        return feats


class InitDecoder(nn.Module):
    """Coarsest decoder: predicts two 2-channel flows plus a feature map.

    The inputs are two feature maps and a time embedding ``embt`` that is
    tiled over the spatial size; the output is upsampled 2x by the transposed
    convolution.
    """

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        width = in_ch * 2
        self.convblock = nn.Sequential(
            convrelu(width + 1, width),
            ResBlock(width, skip_ch),
            nn.ConvTranspose2d(width, out_ch + 4, 4, 2, 1, bias=True),
        )

    def forward(self, f0, f1, embt):
        height, width = f0.shape[2:]
        tiled_t = embt.repeat(1, 1, height, width)
        out = self.convblock(torch.cat([f0, f1, tiled_t], 1))
        # Channels 0-1: flow0, channels 2-3: flow1, remaining: features.
        flow0 = out[:, 0:2]
        flow1 = out[:, 2:4]
        ft_ = out[:, 4:]
        return flow0, flow1, ft_


class IntermediateDecoder(nn.Module):
    """Middle decoder: refines the incoming flows at twice the resolution.

    Both feature maps are warped with the incoming flows (``warp`` is imported
    at the top of this module), decoded together with the previous feature
    map, and the predicted residual flows are added to the 2x-upsampled,
    2x-scaled incoming flows.
    """

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        width = in_ch * 3
        self.convblock = nn.Sequential(
            convrelu(width + 4, width),
            ResBlock(width, skip_ch),
            nn.ConvTranspose2d(width, out_ch + 4, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        warped0 = warp(f0, flow0_in)
        warped1 = warp(f1, flow1_in)
        out = self.convblock(
            torch.cat([ft_, warped0, warped1, flow0_in, flow1_in], 1))
        ft_ = out[:, 4:]
        flow0 = out[:, 0:2] + 2.0 * resize(flow0_in, scale_factor=2.0)
        flow1 = out[:, 2:4] + 2.0 * resize(flow1_in, scale_factor=2.0)
        return flow0, flow1, ft_
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
        A parallel implementation of multiple flow field warping

        comb_block: a callable (an nn.Sequential in the model) that maps the
            stacked candidates [b, 3*num_flows, h, w] to a [b, 3, h, w] residual.
        img shape: [b, c, h, w]
        flow shape: [b, 2*num_flows, h, w]
        mask (opt): [b, num_flows, h, w] blending weights for the warped img0.
            If 'mask' is None, the two warped images are averaged.
        img_res (opt): [b, 3*num_flows, h, w]. If 'img_res' is None, the
            function adds zero instead.
        mean (opt): If 'mean' is None, the function adds zero instead.

        Returns the combined prediction of shape [b, 3, h, w].
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the flow index into the batch axis so that all candidates are
    # warped by a single call.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    if mask is not None:
        mask = mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
    else:
        # Documented fallback: plain average of both warped images. The
        # previous code kept None here and crashed on `mask * img0_warp`.
        mask = 0.5
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # Mean over the candidates plus a learned correction computed from all of them.
    imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
    return imgt_pred


class MultiFlowDecoder(nn.Module):
    """Finest decoder: predicts ``num_flows`` flow pairs, masks and residuals.

    The transposed convolution emits ``8 * num_flows`` channels, split into
    ``2n`` (flow0 deltas), ``2n`` (flow1 deltas), ``n`` (mask logits) and
    ``3n`` (image residuals).
    """

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(MultiFlowDecoder, self).__init__()
        self.num_flows = num_flows
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0, flow1):
        n = self.num_flows
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
        mask = torch.sigmoid(mask)

        # The incoming flow is upsampled 2x (and scaled 2x), then tiled so that
        # every one of the n candidates starts from the same base flow.
        flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
                                           ).repeat(1, self.num_flows, 1, 1)
        flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
                                           ).repeat(1, self.num_flows, 1, 1)

        return flow0, flow1, mask, img_res
def resize(x, scale_factor):
    """Bilinearly rescale ``x`` by ``scale_factor`` (``align_corners=False``)."""
    return F.interpolate(
        x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def bilinear_sampler(img, coords, mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates

    ``coords`` holds (x, y) pixel positions in its last dimension; they are
    mapped to the [-1, 1] range expected by ``grid_sample``. With ``mask=True``
    a float mask of strictly-inside positions is returned as well.
    """
    height, width = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2*xgrid/(width-1) - 1
    ygrid = 2*ygrid/(height-1) - 1

    sampled = F.grid_sample(
        img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if not mask:
        return sampled
    inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, inside.float()


def coords_grid(batch, ht, wd, device):
    """Return a [batch, 2, ht, wd] float grid; channel 0 is x, channel 1 is y."""
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    ys, xs = torch.meshgrid(rows, cols, indexing='ij')
    grid = torch.stack((xs, ys), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)


def _conv_lrelu_conv(in_ch, mid_ch, out_ch):
    """3x3 conv -> LeakyReLU(0.1) -> 3x3 conv, as used by the update blocks."""
    return nn.Sequential(
        nn.Conv2d(in_ch, mid_ch, 3, padding=1),
        nn.LeakyReLU(negative_slope=0.1, inplace=True),
        nn.Conv2d(mid_ch, out_ch, 3, padding=1),
    )


class SmallUpdateBlock(nn.Module):
    """Update block with a single correlation convolution.

    ``forward`` takes the current features ``net``, the concatenated
    4-channel flow and the looked-up correlation, and returns residuals for
    the features and the flow. When ``scale_factor`` is set, ``net`` is
    downscaled by it on entry and both residuals are upscaled by it on exit
    (the flow residual is also multiplied by it).
    """

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2
        self.scale_factor = scale_factor

        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)

        self.gru = _conv_lrelu_conv(fc_dim+4+cdim, hidden_dim, hidden_dim)
        self.feat_head = _conv_lrelu_conv(hidden_dim, hidden_dim, cdim)
        self.flow_head = _conv_lrelu_conv(hidden_dim, hidden_dim, 4)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        scale = self.scale_factor
        if scale is not None:
            net = resize(net, 1 / scale)

        corr_feat = self.lrelu(self.convc1(corr))
        flow_feat = self.lrelu(self.convf1(flow))
        flow_feat = self.lrelu(self.convf2(flow_feat))
        fused = self.lrelu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        hidden = self.gru(torch.cat([fused, flow, net], dim=1))
        delta_net = self.feat_head(hidden)
        delta_flow = self.flow_head(hidden)

        if scale is not None:
            delta_net = resize(delta_net, scale_factor=scale)
            delta_flow = scale * resize(delta_flow, scale_factor=scale)

        return delta_net, delta_flow


class BasicUpdateBlock(nn.Module):
    """Update block with two correlation convolutions.

    Same contract as ``SmallUpdateBlock``; the correlation branch has an extra
    3x3 convolution (``corr_dim`` -> ``corr_dim2``) and the flow head emits
    ``4 * out_num`` channels.
    """

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2
        self.scale_factor = scale_factor

        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)

        self.gru = _conv_lrelu_conv(fc_dim+4+cdim, hidden_dim, hidden_dim)
        self.feat_head = _conv_lrelu_conv(hidden_dim, hidden_dim, cdim)
        self.flow_head = _conv_lrelu_conv(hidden_dim, hidden_dim, 4*out_num)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        scale = self.scale_factor
        if scale is not None:
            net = resize(net, 1 / scale)

        corr_feat = self.lrelu(self.convc1(corr))
        corr_feat = self.lrelu(self.convc2(corr_feat))
        flow_feat = self.lrelu(self.convf1(flow))
        flow_feat = self.lrelu(self.convf2(flow_feat))
        fused = self.lrelu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        hidden = self.gru(torch.cat([fused, flow, net], dim=1))
        delta_net = self.feat_head(hidden)
        delta_flow = self.flow_head(hidden)

        if scale is not None:
            delta_net = resize(delta_net, scale_factor=scale)
            delta_flow = scale * resize(delta_flow, scale_factor=scale)

        return delta_net, delta_flow


class BidirCorrBlock:
    """All-pairs correlation volume between two feature maps, in both directions.

    The constructor builds an average-pooled pyramid of the correlation volume
    and of its transpose. Calling the object samples a
    ``(2*radius+1) x (2*radius+1)`` window around the given coordinates at
    every pyramid level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        volume = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = volume.shape
        # Transposed volume: indexed by fmap2 position, spatial over fmap1.
        volume_T = volume.clone().permute(0, 4, 5, 3, 1, 2)

        level = volume.reshape(batch*h1*w1, dim, h2, w2)
        level_T = volume_T.reshape(batch*h2*w2, dim, h1, w1)
        self.corr_pyramid.append(level)
        self.corr_pyramid_T.append(level_T)

        for _ in range(1, self.num_levels):
            level = F.avg_pool2d(level, 2, stride=2)
            level_T = F.avg_pool2d(level_T, 2, stride=2)
            self.corr_pyramid.append(level)
            self.corr_pyramid_T.append(level_T)

    def __call__(self, coords0, coords1):
        """Look up both volumes; returns two [B, levels*(2r+1)^2, H, W] float tensors."""
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape
        side = 2*r + 1

        # The sampling window does not depend on the pyramid level.
        offsets = torch.linspace(-r, r, side, device=coords0.device)
        window = torch.stack(
            torch.meshgrid(offsets, offsets, indexing='ij'), dim=-1
        ).view(1, side, side, 2)

        centers0 = coords0.reshape(batch*h1*w1, 1, 1, 2)
        centers1 = coords1.reshape(batch*h1*w1, 1, 1, 2)

        sampled, sampled_T = [], []
        for lvl in range(self.num_levels):
            # Coordinates shrink by 2 per level, matching the pooled volumes.
            pts0 = centers0 / 2**lvl + window
            pts1 = centers1 / 2**lvl + window
            taps = bilinear_sampler(self.corr_pyramid[lvl], pts0)
            taps_T = bilinear_sampler(self.corr_pyramid_T[lvl], pts1)
            sampled.append(taps.view(batch, h1, w1, -1))
            sampled_T.append(taps_T.view(batch, h1, w1, -1))

        out = torch.cat(sampled, dim=-1).permute(0, 3, 1, 2)
        out_T = torch.cat(sampled_T, dim=-1).permute(0, 3, 1, 2)
        return out.contiguous().float(), out_T.contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dot-product correlation of every position pair, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        flat1 = fmap1.view(batch, dim, ht*wd)
        flat2 = fmap2.view(batch, dim, ht*wd)

        volume = torch.matmul(flat1.transpose(1, 2), flat2)
        volume = volume.view(batch, ht, wd, 1, ht, wd)
        return volume / torch.sqrt(torch.tensor(dim).float())
``` python opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/folder --frame_rate 30 ``` 3. The output video will be stored at output_path and its duration equals `the total number of frames after frame interpolation / the frame rate` ##### Frame Interpolation Specific Settings * `--ckpt`: Pretrained model of [AMT](https://github.com/MCG-NKU/AMT). We use AMT-G as our frame interpolation model. * `--niters`: Iterations of interpolation. With $m$ input frames, `[N_ITER]` $=n$ corresponds to $2^n\times (m-1)+1$ output frames. * `--input`: Path of the input video. * `--output_path`: Folder path of the output video. * `--frame_rate`: Frame rate of the output video. ================================================ FILE: opensora/models/frame_interpolation/utils/__init__.py ================================================ ================================================ FILE: opensora/models/frame_interpolation/utils/build_utils.py ================================================ import importlib def base_build_fn(module, cls, params): return getattr(importlib.import_module( module, package=None), cls)(**params) def build_from_cfg(config): module, cls = config['name'].rsplit(".", 1) params = config.get('params', {}) return base_build_fn(module, cls, params) ================================================ FILE: opensora/models/frame_interpolation/utils/dist_utils.py ================================================ import os import torch def get_world_size(): """Find OMPI world size without calling mpi functions :rtype: int """ if os.environ.get('PMI_SIZE') is not None: return int(os.environ.get('PMI_SIZE') or 1) elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) else: return torch.cuda.device_count() def get_global_rank(): """Find OMPI world rank without calling mpi functions :rtype: int """ if 
def warp(img, flow):
    """Backward-warp ``img`` with a per-pixel displacement field.

    Args:
        img: Tensor of shape [B, C, H, W].
        flow: Tensor of shape [B, 2, H, W]; channel 0 is the x displacement
            and channel 1 the y displacement, both in pixels.

    Returns:
        ``img`` sampled bilinearly at ``position + flow`` with border padding.
    """
    B, _, H, W = flow.shape
    # Build the base grid directly on the target device (no host-to-device
    # copy per call), then match img's dtype as before.
    xx = torch.linspace(-1.0, 1.0, W, device=img.device).view(1, 1, 1, W).expand(B, -1, H, -1)
    yy = torch.linspace(-1.0, 1.0, H, device=img.device).view(1, 1, H, 1).expand(B, -1, -1, W)
    grid = torch.cat([xx, yy], 1).to(img)
    # Pixels -> normalised [-1, 1] units. max(..., 1) avoids a division by
    # zero when an axis has a single pixel.
    half_w = max(W - 1.0, 1.0) / 2.0
    half_h = max(H - 1.0, 1.0) / 2.0
    flow_ = torch.cat([flow[:, 0:1, :, :] / half_w,
                       flow[:, 1:2, :, :] / half_h], 1)
    grid_ = (grid + flow_).permute(0, 2, 3, 1)
    output = F.grid_sample(input=img, grid=grid_, mode='bilinear',
                           padding_mode='border', align_corners=True)
    return output


def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.
    Returns:
        np.ndarray: Color wheel of shape [55, 3] with values in [0, 255]
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.
    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun
    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    # Map the angle from [-1, 1] onto a fractional colour-wheel index.
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    f = fk - k0
    idx = (rad <= 1)
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        # Saturation grows with the radius; zero flow maps to white.
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Converts a two-channel flow field of shape [H,W,2] into a colour image.
    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Largest absolute value kept per flow
            component; components are clipped to [-clip_flow, clip_flow].
            Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Clip symmetrically. A lower bound of 0 would erase every negative
        # component and with it the direction of leftward/upward motion.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5  # keeps the normalisation finite for all-zero flow
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
class AverageMeter():
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class AverageMeterGroups:
    """A dictionary of named ``AverageMeter`` objects created on first use."""

    def __init__(self) -> None:
        self.meter_dict = dict()

    def update(self, dict, n=1):
        """Update one meter per ``name -> value`` entry of ``dict``."""
        for name, val in dict.items():
            if self.meter_dict.get(name) is None:
                self.meter_dict[name] = AverageMeter()
            self.meter_dict[name].update(val, n)

    def reset(self, name=None):
        """Reset the meter called ``name``, or every meter when ``name`` is None."""
        if name is None:
            for v in self.meter_dict.values():
                v.reset()
        else:
            meter = self.meter_dict.get(name)
            if meter is not None:
                meter.reset()

    def avg(self, name):
        """Return the running average of ``name``, or None if it is unknown."""
        meter = self.meter_dict.get(name)
        if meter is not None:
            return meter.avg


class InputPadder:
    """ Pads images such that dimensions are divisible by divisor """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # F.pad order: [left, right, top, bottom]
        self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]

    def pad(self, *inputs):
        """Replicate-pad one tensor (returned directly) or several (returned as a list)."""
        if len(inputs) == 1:
            return F.pad(inputs[0], self._pad, mode='replicate')
        else:
            return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, *inputs):
        """Crop the padding added by ``pad`` from one or several tensors."""
        if len(inputs) == 1:
            return self._unpad(inputs[0])
        else:
            return [self._unpad(x) for x in inputs]

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]


def img2tensor(img):
    """Convert an [H, W, C] image array to a [1, 3, H, W] tensor divided by 255.

    Channels beyond the third (e.g. alpha) are dropped.
    """
    if img.shape[-1] > 3:
        img = img[:,:,:3]
    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0


def tensor2img(img_t):
    """Convert a [1, C, H, W] tensor in [0, 1] to an [H, W, C] uint8 array."""
    return (img_t * 255.).detach(
                        ).squeeze(0).permute(1, 2, 0).cpu().numpy(
                        ).clip(0, 255).astype(np.uint8)


def seed_all(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def read(file):
    """Read ``file`` with the reader selected by its extension."""
    if file.endswith('.float3'): return readFloat(file)
    elif file.endswith('.flo'): return readFlow(file)
    elif file.endswith('.ppm'): return readImage(file)
    elif file.endswith('.pgm'): return readImage(file)
    elif file.endswith('.png'): return readImage(file)
    elif file.endswith('.jpg'): return readImage(file)
    elif file.endswith('.pfm'): return readPFM(file)[0]
    else: raise Exception('don\'t know how to read %s' % file)


def write(file, data):
    """Write ``data`` to ``file`` with the writer selected by its extension."""
    if file.endswith('.float3'): return writeFloat(file, data)
    elif file.endswith('.flo'): return writeFlow(file, data)
    elif file.endswith('.ppm'): return writeImage(file, data)
    elif file.endswith('.pgm'): return writeImage(file, data)
    elif file.endswith('.png'): return writeImage(file, data)
    elif file.endswith('.jpg'): return writeImage(file, data)
    elif file.endswith('.pfm'): return writePFM(file, data)
    else: raise Exception('don\'t know how to write %s' % file)


def readPFM(file):
    """Read a PFM image.

    Args:
        file: Path of the PFM file.

    Returns:
        ``(data, scale)``: ``data`` is [H, W, 3] for a colour ('PF') file or
        [H, W] for a grayscale ('Pf') file, with rows flipped so that row 0 is
        the top of the image; ``scale`` is the absolute value of the header's
        scale field.

    Raises:
        Exception: If the header is not a valid PFM header.
    """
    # `with` guarantees the handle is closed, also on malformed headers.
    with open(file, 'rb') as stream:
        header = stream.readline().rstrip()
        if header.decode("ascii") == 'PF':
            color = True
        elif header.decode("ascii") == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', stream.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')

        # The sign of the scale encodes the byte order: negative = little-endian.
        scale = float(stream.readline().decode("ascii").rstrip())
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(stream, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data, scale


def writePFM(file, image, scale=1):
    """Write a float32 image as a PFM file.

    Args:
        file: Destination path.
        image: float32 array of shape [H, W, 3], [H, W, 1] or [H, W].
        scale: Scale stored in the header; its sign is set from the byte order.

    Raises:
        Exception: If ``image`` is not float32 or has an unsupported shape.
            The input is validated before the file is created, so a rejected
            image no longer leaves an empty file behind.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1):
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # PFM stores rows bottom-to-top.
    image = np.flipud(image)

    endian = image.dtype.byteorder
    if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
        scale = -scale

    with open(file, 'wb') as stream:
        # Both magic numbers must be bytes: the old expression
        # `'PF\n' if color else 'Pf\n'.encode()` encoded only the grayscale
        # branch and raised TypeError for colour images.
        stream.write(b'PF\n' if color else b'Pf\n')
        stream.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
        stream.write(b'%f\n' % scale)
        image.tofile(stream)


def readFlow(name):
    """Read an optical-flow file (``.flo`` with a 'PIEH' header, or ``.pfm``).

    Returns:
        float32 array of shape [H, W, 2].

    Raises:
        Exception: If a ``.flo`` file does not start with 'PIEH'.
    """
    if name.endswith('.pfm') or name.endswith('.PFM'):
        return readPFM(name)[0][:,:,0:2]

    with open(name, 'rb') as f:
        header = f.read(4)
        if header.decode("utf-8") != 'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        # Plain Python ints keep the element count exact for large images.
        width = int(np.fromfile(f, np.int32, 1).squeeze())
        height = int(np.fromfile(f, np.int32, 1).squeeze())

        flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))

    return flow.astype(np.float32)


def readImage(name):
    """Read an image; PFM files keep at most their first three channels."""
    if name.endswith('.pfm') or name.endswith('.PFM'):
        data = readPFM(name)[0]
        if len(data.shape)==3:
            return data[:,:,0:3]
        else:
            return data
    return imread(name)


def writeImage(name, data):
    """Write an image; ``.pfm`` goes through ``writePFM``, the rest through ``imwrite``."""
    if name.endswith('.pfm') or name.endswith('.PFM'):
        return writePFM(name, data, 1)
    return imwrite(name, data)
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) flow = flow.astype(np.float32) flow.tofile(f) def readFloat(name): f = open(name, 'rb') if(f.readline().decode("utf-8")) != 'float\n': raise Exception('float file %s did not contain keyword' % name) dim = int(f.readline()) dims = [] count = 1 for i in range(0, dim): d = int(f.readline()) dims.append(d) count *= d dims = list(reversed(dims)) data = np.fromfile(f, np.float32, count).reshape(dims) if dim > 2: data = np.transpose(data, (2, 1, 0)) data = np.transpose(data, (1, 0, 2)) return data def writeFloat(name, data): f = open(name, 'wb') dim=len(data.shape) if dim>3: raise Exception('bad float file dimension: %d' % dim) f.write(('float\n').encode('ascii')) f.write(('%d\n' % dim).encode('ascii')) if dim == 1: f.write(('%d\n' % data.shape[0]).encode('ascii')) else: f.write(('%d\n' % data.shape[1]).encode('ascii')) f.write(('%d\n' % data.shape[0]).encode('ascii')) for i in range(2, dim): f.write(('%d\n' % data.shape[i]).encode('ascii')) data = data.astype(np.float32) if dim==2: data.tofile(f) else: np.transpose(data, (2, 0, 1)).tofile(f) def check_dim_and_resize(tensor_list): shape_list = [] for t in tensor_list: shape_list.append(t.shape[2:]) if len(set(shape_list)) > 1: desired_shape = shape_list[0] print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}') resize_tensor_list = [] for t in tensor_list: resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')) tensor_list = resize_tensor_list return tensor_list ================================================ FILE: opensora/models/prompt_refiner/inference.py ================================================ from transformers import AutoModelForCausalLM, AutoTokenizer import torch from tqdm import tqdm import argparse def get_output(prompt): template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. 
" \ "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ "Make sure it is a fluent sentence, not nonsense." prompt = template.format(prompt) messages = [ {"role": "system", "content": "You are a caption refiner."}, {"role": "user", "content": prompt} ] input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) model_inputs = tokenizer([input_ids], return_tensors="pt").to(device) generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512) generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] print('\nInput\n:', prompt) print('\nOutput\n:', response) return response def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--mode_path", type=str, default="llama3_8B_lora_merged_cn") parser.add_argument("--prompt", type=str, default='a dog is running.') args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() device = torch.device('cuda') tokenizer = AutoTokenizer.from_pretrained(args.mode_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(args.mode_path,torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval() response = get_output(args.prompt) ================================================ FILE: opensora/models/prompt_refiner/merge.py ================================================ import os from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel import torch import argparse def get_lora_model(base_model_path, lora_model_input_path, lora_model_output_path): model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto",trust_remote_code=True) model = PeftModel.from_pretrained(model, lora_model_input_path) merged_model = 
model.merge_and_unload() merged_model.save_pretrained(lora_model_output_path, safe_serialization=True) print("Merge lora to base model") tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True) tokenizer.save_pretrained(lora_model_output_path) print("Save tokenizer") def get_model_result(base_model_path, fintune_model_path): tokenizer = AutoTokenizer.from_pretrained(base_model_path) device = "cuda" fintune_model = AutoModelForCausalLM.from_pretrained( fintune_model_path, device_map="auto", torch_dtype=torch.bfloat16, ).eval() base_model = AutoModelForCausalLM.from_pretrained( base_model_path, device_map="auto", torch_dtype=torch.bfloat16, ).eval() template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \ "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ "Make sure it is a fluent sentence, not nonsense." prompt = "a dog和一只猫" prompt = template.format(prompt) messages = [ {"role": "system", "content": "You are a caption refiner."}, {"role": "user", "content": prompt} ] text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) model_inputs = tokenizer([text], return_tensors="pt").to(device) def get_result(model_inputs, model): generated_ids = model.generate( model_inputs.input_ids, max_new_tokens=512, eos_token_id=tokenizer.get_vocab()["<|eot_id|>"] ) generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] return response base_model_response = get_result(model_inputs, base_model) fintune_model_response = get_result(model_inputs, fintune_model) print("\nInput\n", prompt) print("\nResult before fine-tune:\n", base_model_response) print("\nResult after fine-tune:\n", fintune_model_response) def parse_args(): parser = argparse.ArgumentParser() 
parser.add_argument("--base_path", type=str, default="Meta-Llama-3___1-8B-Instruct") parser.add_argument("--lora_in_path", type=str, default="llama3_1_instruct_lora/checkpoint-1008") parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora/llama3_8B_lora_merged_cn") args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() get_lora_model(args.base_path, args.lora_in_path, args.lora_out_path) get_model_result(args.base_path, args.lora_out_path) ================================================ FILE: opensora/models/prompt_refiner/train.py ================================================ from datasets import Dataset import pandas as pd from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig from peft import LoraConfig, TaskType, get_peft_model import torch import argparse ins = "Refine the sentence to contain subject description, action, scene description. " \ "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ "Make sure it is a fluent sentence, not nonsense." 
def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--data_path", type=str, default='refine_32255.json') parser.add_argument("--model_path", type=str, default='Meta-Llama-3___1-8B-Instruct') parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora") args = parser.parse_args() return args args = parse_args() df = pd.read_json(args.data_path) ds = Dataset.from_pandas(df) tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token def process_func(example): MAX_LENGTH = 2048 input_ids, attention_mask, labels = [], [], [] instruction = tokenizer(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a caption refiner.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False) # add_special_tokens 不在开头加 special_tokens response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False) input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id] attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1] labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id] if len(input_ids) > MAX_LENGTH: input_ids = input_ids[:MAX_LENGTH] attention_mask = attention_mask[:MAX_LENGTH] labels = labels[:MAX_LENGTH] return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels } tokenized_id = ds.map(process_func, remove_columns=ds.column_names) model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto",torch_dtype=torch.bfloat16) print(model) model.enable_input_require_grads() config = LoraConfig( task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], inference_mode=False, r=64, lora_alpha=64, lora_dropout=0.1 
) print(config) model = get_peft_model(model, config) model.print_trainable_parameters() args = TrainingArguments( output_dir=args.lora_out_path, per_device_train_batch_size=32, gradient_accumulation_steps=1, logging_steps=1, num_train_epochs=1, save_steps=20, dataloader_num_workers=4, learning_rate=1.5e-4, warmup_ratio=0.1, save_on_each_node=True, gradient_checkpointing=True, report_to='wandb', ) trainer = Trainer( model=model, args=args, train_dataset=tokenized_id, data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True), ) trainer.train() ================================================ FILE: opensora/models/text_encoder/__init__.py ================================================ from opensora.models.text_encoder.clip import CLIPWrapper from opensora.models.text_encoder.t5 import T5Wrapper text_encoder = { 'google/mt5-xl': T5Wrapper, 'google/mt5-xxl': T5Wrapper, 'google/umt5-xl': T5Wrapper, 'google/umt5-xxl': T5Wrapper, 'google/t5-v1_1-xl': T5Wrapper, 'DeepFloyd/t5-v1_1-xxl': T5Wrapper, 'openai/clip-vit-large-patch14': CLIPWrapper, 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k': CLIPWrapper } def get_text_warpper(text_encoder_name): """deprecation""" encoder_key = None for key in text_encoder.keys(): if key in text_encoder_name: encoder_key = key break text_enc = text_encoder.get(encoder_key, None) assert text_enc is not None return text_enc ================================================ FILE: opensora/models/text_encoder/clip.py ================================================ import torch from torch import nn from transformers import CLIPTextModelWithProjection try: import torch_npu except: torch_npu = None class CLIPWrapper(nn.Module): def __init__(self, args, **kwargs): super(CLIPWrapper, self).__init__() self.model_name = args.text_encoder_name_2 if torch_npu is not None: self.model_name = '/home/save_dir/pretrained/clip/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/bc7788f151930d91b58474715fdce5524ad9a189' else: self.model_name 
= '/storage/cache_dir/CLIP-ViT-bigG-14-laion2B-39B-b160k' print(f'Loading CLIP model from {self.model_name}...') self.text_enc = CLIPTextModelWithProjection.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval() def forward(self, input_ids, attention_mask): text_encoder_embs = self.text_enc(input_ids=input_ids, output_hidden_states=True)[0] return text_encoder_embs.detach() ================================================ FILE: opensora/models/text_encoder/t5.py ================================================ import torch from torch import nn from transformers import T5EncoderModel try: import torch_npu except: torch_npu = None class T5Wrapper(nn.Module): def __init__(self, args, **kwargs): super(T5Wrapper, self).__init__() self.model_name = args.text_encoder_name_1 print(f'Loading T5 model from {self.model_name}...') self.text_enc = T5EncoderModel.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval() def forward(self, input_ids, attention_mask): text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'] return text_encoder_embs.detach() ================================================ FILE: opensora/npu_config.py ================================================ import math import mmap import os import pickle import random import numpy as np import torch import subprocess import sys import threading import gc import torch.distributed as dist from opensora.adaptor.zp_manager import zp_manager try: import torch_npu npu_is_available = True from torch_npu.contrib import transfer_to_npu except: npu_is_available = False from contextlib import contextmanager import types def compress_video(input_file, output_file, out_size): """使用 ffmpeg 压缩视频文件。""" command = [ 'ffmpeg', '-i', input_file, '-vf', f"scale='min({out_size},iw)':'min({out_size},ih)':force_original_aspect_ratio=decrease", '-c:v', 'libx264', '-crf', '18', '-preset', 'slow', '-c:a', 'copy', output_file ] 
subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) @contextmanager def set_run_dtype(x, dtype=None): # 保存原始环境变量的值(如果存在) npu_config.original_run_dtype = x.dtype # 设置环境变量为指定的值 npu_config.current_run_dtype = dtype try: # Yield control back to the body of the `with` statement yield finally: # 恢复原始的环境变量值 npu_config.current_run_dtype = None npu_config.original_run_dtype = None class NPUConfig: N_NPU_PER_NODE = 8 def __init__(self): self.on_npu = npu_is_available self.node_world_size = self.N_NPU_PER_NODE self.profiling = False self.profiling_step = 5 self.enable_FA = True self.enable_FP32 = False self.load_pickle = True self.use_small_dataset = False self.current_run_dtype = None self.original_run_dtype = None self.zp_manager = zp_manager self.replaced_type = torch.float32 self.conv_dtype = torch.float16 if self.enable_FA and self.enable_FP32: self.inf_float = -10000.0 else: self.inf_float = -10000.0 if self.use_small_dataset: self.load_pickle = False self._loss = [] self.work_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) self.pickle_save_path = f"{self.work_path}/pickles" self.mm = dict() if self.on_npu: import deepspeed import sys torch_npu.npu.set_compile_mode(jit_compile=False) import deepspeed.runtime.utils as utils from opensora.adaptor.utils import all_gather_dp_groups, all_gather_into_tensor_dp_groups utils.all_gather_dp_groups = all_gather_dp_groups import deepspeed.runtime.bf16_optimizer as bf16_optimizer from opensora.adaptor.bf16_optimizer import BF16_Optimizer self.replace_methods(bf16_optimizer.BF16_Optimizer, BF16_Optimizer) from opensora.adaptor.stage_1_and_2 import DeepSpeedZeroOptimizer import deepspeed.runtime.zero.stage_1_and_2 as stage_1_and_2 self.replace_methods(stage_1_and_2.DeepSpeedZeroOptimizer, DeepSpeedZeroOptimizer, ['_has_inf_or_nan']) import deepspeed.runtime.engine as engine from opensora.adaptor.engine import DeepSpeedEngine self.replace_methods(engine.DeepSpeedEngine, DeepSpeedEngine, 
skip_fcns=['__init__', '_copy_recovery_script', '_change_recovery_script_permissions']) if "RANK" in os.environ: self.rank = int(os.environ["RANK"]) self.world_size = int(os.environ["WORLD_SIZE"]) torch_npu.npu.set_device(self.get_local_rank()) else: self.rank = torch.cuda.current_device() self.world_size = self.N_NPU_PER_NODE self.print_with_rank(f"The npu_config.on_npu is {self.on_npu}") self.bind_thread_to_cpu() gc.set_threshold(700, 10, 10000) def get_total_cores(self): try: total_cores = os.sysconf('SC_NPROCESSORS_ONLN') except (AttributeError, ValueError): total_cores = os.cpu_count() return total_cores def bind_thread_to_cpu(self): total_cores = self.get_total_cores() # 每个卡的核心数量 cores_per_rank = total_cores // 8 # 计算本地rank local_rank = self.rank % 8 # 计算当前 rank 的 CPU 核范围 start_core = local_rank * cores_per_rank end_core = start_core + cores_per_rank - 1 # 构建 CPU 核范围字符串 cpu_cores_range = f"{start_core}-{end_core}" pid = os.getpid() command = f"taskset -cp {cpu_cores_range} {pid}" subprocess.run(command, shell=True, check=True) return f"Binding Cores:{self.rank}:{pid}:{cpu_cores_range}" def replace_methods(self, target_class, source_class, skip_fcns=[], only_include_fcns=None): for attr_name in dir(source_class): attr_value = getattr(source_class, attr_name) if attr_name in source_class.__dict__: attr_class_value = source_class.__dict__[attr_name] else: attr_class_value = attr_value if (isinstance(attr_class_value, staticmethod) or isinstance(attr_class_value, classmethod) or attr_name in skip_fcns): print(f"skip replace {attr_name}") continue if only_include_fcns is not None and attr_name not in only_include_fcns: continue elif isinstance(attr_value, types.FunctionType): setattr(target_class, attr_name, attr_value) def get_attention_mask(self, attention_mask, repeat_num): if self.on_npu and attention_mask is not None: if npu_config.enable_FA: attention_mask = attention_mask.to(torch.bool) attention_mask = attention_mask.repeat_interleave(repeat_num, dim=-2) 
return attention_mask def set_current_run_dtype(self, variables): if variables[0].dtype != self.current_run_dtype and self.current_run_dtype is not None: for index, var in enumerate(variables): variables[index] = var.to(self.current_run_dtype) return tuple(variables) def restore_dtype(self, x): if x.dtype != self.original_run_dtype and self.original_run_dtype is not None: x = x.to(self.original_run_dtype) return x def get_output_video_path(self, name): os.makedirs(f"{self.work_path}/output_videos", exist_ok=True) return f"{self.work_path}/output_videos/{name}" def get_node_id(self): return self.rank // self.node_world_size def get_node_size(self): return self.world_size // self.node_world_size def get_local_rank(self): return self.rank % self.N_NPU_PER_NODE def get_pickle_path(self, file_name): return f"{self.pickle_save_path}/{file_name}_local_n63" def free_mm(self): for key, value in self.mm.items(): value.close() self.mm.clear() def __del__(self): self.free_mm() def try_load_pickle(self, file_name, function): file_name = self.get_pickle_path(file_name) if os.path.exists(file_name) and self.load_pickle: with open(file_name, 'rb') as file: # self.mm[file_name] = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ) # # 使用 mmap 进行数据读取 # loaded_data = pickle.loads(self.mm[file_name][:]) loaded_data = pickle.load(file) return loaded_data else: data = function() if not self.use_small_dataset: if self.rank % self.N_NPU_PER_NODE == 0: # 只需要rank0保存文件 os.makedirs(self.pickle_save_path, exist_ok=True) with open(file_name, 'wb') as file: pickle.dump(data, file, pickle.HIGHEST_PROTOCOL) return data def try_get_vid_path(self, file, out_size=1024): output_file = file.rsplit(".", 1)[0] + f"_resize{out_size}.mp4" if not os.path.exists(output_file): return file # compress_video(file, output_file, out_size) return output_file def npu_format_cast(self, x): return torch_npu.npu_format_cast(x, 2) def calc_grad_norm(self, model): # 计算并打印梯度范数 # model_engine = 
accelerator.deepspeed_engine_wrapped.engine # gradients = model_engine.get_gradients() # grad_norm = get_grad_norm(gradients) # 计算并打印梯度范数 grad_norm = 0 n_grad = 0 # for name, param in model.named_parameters(): # grad_data = deepspeed.utils.safe_get_full_grad(param) # # self.print_tensor_stats(grad_data, name=name) # # if grad_data is not None: # param_norm = grad_data.norm(2) # grad_norm += param_norm.item() ** 2 # n_grad += 1 # grad_norm = (grad_norm / n_grad) ** (1. / 2) return grad_norm def _run(self, operator, x, tmp_dtype, out_dtype=None, out_nd_format=False): if self.on_npu: if out_dtype is None: out_dtype = x.dtype with torch.cuda.amp.autocast(enabled=False): x = operator.to(device=x.device, dtype=tmp_dtype)(x.to(tmp_dtype)) x = x.to(out_dtype) if out_nd_format: return self.npu_format_cast(x) else: return x else: return operator(x) def run_group_norm(self, operator, x): return self._run(operator, x, torch.float32) def run_layer_norm(self, operator, x): return self._run(operator, x, torch.float32) def print_tensor_stats(self, tensor, name="Tensor", rank=None): if rank and rank != self.rank: return if tensor is None: self.print_msg(f"Tensor {name} is None.") return x_dtype = tensor.dtype tensor = tensor.to(torch.bfloat16) max_val = tensor.max().item() min_val = tensor.min().item() abs_max_val = min(abs(max_val), abs(min_val)) mean_val = tensor.mean().item() median_val = tensor.median().item() std_val = tensor.std().item() shape = tensor.shape self.print_msg( f"{name} - Max: {max_val}, Min: {min_val}, Mean: {mean_val}, AbsMax: {abs_max_val}," f"Median: {median_val}, Std: {std_val}, Shape: {shape}, Type: {x_dtype}") def run_conv3d(self, operator, x, out_dtype): return self._run(operator, x, self.conv_dtype, out_dtype, out_nd_format=True) def run_pool_2d(self, operator, x): return self._run(operator, x, self.replaced_type) def run_pad_2d(self, operator, x, pad, mode="constant"): if self.on_npu: x_dtype = x.dtype x = x.to(self.replaced_type) x = operator(x, pad, 
mode) x = x.to(x_dtype) else: x = operator(x, pad, mode) return x def seed_everything(self, seed=100): seed += self.rank random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def print_with_rank(self, msg, rank=0, save=False): if self.rank == rank: print(f"{msg}", flush=True) if save: self._loss.append(msg) def print_msg(self, msg, on=True, rank=None): if on: if self.rank == rank or rank is None: print(f"[RANK-{self.rank}]: {msg}", flush=True) def save_loss(self, filename, rank=0): if self.rank == rank: import json with open(filename, 'w') as file: json.dump(self._loss, file, indent=4) def run_attention(self, query, key, value, atten_mask, input_layout, head_dim, head_num): if self.enable_FA: hidden_states = torch_npu.npu_fusion_attention(query, key, value, atten_mask=atten_mask, input_layout=input_layout, scale=1 / math.sqrt(head_dim), head_num=head_num)[0] else: hidden_states = self.scaled_dot_product_attention(query, key, value, atten_mask=atten_mask, input_layout=input_layout, scale=1 / math.sqrt(head_dim), head_num=head_num) return hidden_states def scaled_dot_product_attention(self, query, key, value, input_layout, head_num=None, atten_mask=None, scale=None, dropout_p=0.0, is_causal=False) -> torch.Tensor: # L, S = query.size(-2), key.size(-2) def trans_tensor_shape(x, layout, head_num): if layout == "BSH": batch = x.shape[0] x = x.view(batch, -1, head_num, x.shape[-1] // head_num).transpose(1, 2).contiguous() elif layout == "SBH": batch = x.shape[1] x = x.view(-1, batch * head_num, x.shape[-1] // head_num).transpose(0, 1).contiguous() x = x.view(batch, head_num, -1, x.shape[-1]) return x query = trans_tensor_shape(query, input_layout, head_num) key = trans_tensor_shape(key, input_layout, head_num) value = trans_tensor_shape(value, input_layout, head_num) attn_weight = query @ key.transpose(-2, -1) * scale 
attn_bias = torch.zeros_like(attn_weight, dtype=query.dtype, device=query.device) if is_causal: assert atten_mask is None temp_mask = torch.zeros_like(attn_weight, dtype=torch.bool, device=query.device).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), npu_config.inf_float) attn_bias.to(query.dtype) if atten_mask is not None: assert (not self.enable_FA) and atten_mask.dtype != torch.bool, \ "attention_mask must not be bool type when use this function" attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) output = attn_weight @ value if input_layout == "BSH": output = output.transpose(1, 2).contiguous().view(output.shape[0], -1, head_num * output.shape[-1]) else: output = output.view(output.shape[0] * head_num, -1, output.shape[-1]).transpose(0, 1).contiguous() output = output.view(output.shape[0], -1, head_num * output.shape[-1]) return output def print_tensor_with_rank(self, name, tensor, rank=[0], dim_print_cnt=[]): if type(rank) is not list: rank = [rank] if self.rank in rank: def print_dim(tensor_, indices): if tensor_.dim() == len(indices): return '{0:10.5f} '.format(tensor[tuple(indices)].detach().item()) else: cur_dim = len(indices) ret = '' for x in range(0, tensor_.size(cur_dim), tensor_.size(cur_dim) // dim_print_cnt[cur_dim]): ret += print_dim(tensor_, indices + [x]) return ret + '\n' print(name, tensor.size(), self.rank, '\n', print_dim(tensor, [])) npu_config = NPUConfig() ================================================ FILE: opensora/sample/caption_refiner.py ================================================ import torch from torch import nn from transformers import AutoTokenizer, AutoModelForCausalLM TEMPLATE = """ Refine the sentence: \"{}\" to contain subject description, action, scene description. " \ "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. 
" \ "Make sure it is a fluent sentence, not nonsense. """ class OpenSoraCaptionRefiner(nn.Module): def __init__(self, args, dtype, device): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained( args.caption_refiner, trust_remote_code=True ) self.model = AutoModelForCausalLM.from_pretrained( args.caption_refiner, torch_dtype=dtype, trust_remote_code=True ).to(device).eval() self.device = device def get_refiner_output(self, prompt): prompt = TEMPLATE.format(prompt) messages = [ {"role": "system", "content": "You are a caption refiner."}, {"role": "user", "content": prompt} ] input_ids = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) model_inputs = self.tokenizer([input_ids], return_tensors="pt").to(self.device) generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=512) generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] return response ================================================ FILE: opensora/sample/pipeline_inpaint.py ================================================ import inspect import os from typing import Callable, Dict, List, Optional, Tuple, Union from dataclasses import dataclass from altair import condition import numpy as np import torch from einops import rearrange from PIL import Image import decord from transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel import torch.nn.functional as F from torchvision.transforms import Compose, Lambda, Resize from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import VaeImageProcessor from diffusers.models import AutoencoderKL, HunyuanDiT2DModel from diffusers.models.embeddings import 
get_2d_rotary_pos_embed from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDPMScheduler from diffusers.utils import logging, BaseOutput from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from opensora.models.diffusion.opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3 from opensora.sample.pipeline_opensora import OpenSoraPipeline, OpenSoraPipelineOutput, rescale_noise_cfg from opensora.dataset.transform import CenterCropResizeVideo, SpatialStrideCropVideo,ToTensorAfterResize, maxhwresize from opensora.utils.mask_utils import MaskProcessor, MaskCompressor, GaussianNoiseAdder, MaskType, STR_TO_TYPE, TYPE_TO_STR try: import torch_npu from opensora.npu_config import npu_config from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info except: torch_npu = None npu_config = None from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info logger = logging.get_logger(__name__) # pylint: disable=invalid-name def is_video_file(file_path): video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.mpeg', '.mpg', '.3gp'} file_extension = os.path.splitext(file_path)[1].lower() return file_extension in video_extensions def is_image_file(file_path): image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'} file_extension = os.path.splitext(file_path)[1].lower() return file_extension in image_extensions def open_image(file_path): image = Image.open(file_path).convert("RGB") return image def open_video(file_path, start_frame_idx, num_frames, frame_interval=1): decord_vr = decord.VideoReader(file_path, ctx=decord.cpu(0), num_threads=1) total_frames = len(decord_vr) frame_indices = list(range(start_frame_idx, min(start_frame_idx + num_frames * frame_interval, total_frames), frame_interval)) if len(frame_indices) == 0: raise ValueError("No 
class OpenSoraInpaintPipeline(OpenSoraPipeline):
    """Inpainting variant of ``OpenSoraPipeline``.

    On top of the base text-to-video pipeline, it conditions the denoiser on
    user-supplied images or a video. The conditioning frames are masked,
    VAE-encoded and concatenated, together with a compressed mask, to the
    latent model input along the channel dimension at every denoising step.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: MT5Tokenizer,
        transformer: OpenSoraInpaint_v1_3,
        scheduler: DDPMScheduler,
        text_encoder_2: CLIPTextModelWithProjection = None,
        tokenizer_2: CLIPTokenizer = None,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
        )

        # If performing continuation or random, the default mask is half of the frame, which can be modified
        self.mask_processor = MaskProcessor(min_clear_ratio=0.5, max_clear_ratio=0.5)

        # The mask is compressed with the VAE strides (t, h, w).
        self.mask_compressor = MaskCompressor(
            ae_stride_t=self.vae.vae_scale_factor[0],
            ae_stride_h=self.vae.vae_scale_factor[1],
            ae_stride_w=self.vae.vae_scale_factor[2],
        )

        # Set per call in __call__ according to `noise_strength`.
        self.noise_adder = None

    def check_inputs(
        self,
        conditional_pixel_values_path,
        conditional_pixel_values_indices,
        mask_type,
        max_hxw,
        noise_strength,
        prompt,
        num_frames,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        prompt_embeds_2=None,
        negative_prompt_embeds_2=None,
        prompt_attention_mask_2=None,
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate the inpainting-specific arguments, then delegate to the base checks.

        Raises:
            ValueError: if the conditioning paths are missing, empty, not a list of
                strings, not image/video paths, or several paths are given for a video;
                if the indices are not a list of ints matching the number of paths;
                if `mask_type` is unknown; if `max_hxw` is not an int in
                [102400, 236544]; or if `noise_strength` is not a number in [0, 1].
        """
        if conditional_pixel_values_path is None:
            raise ValueError("conditional_pixel_values_path should be provided")

        # An empty list previously surfaced as IndexError on `[0]`; report it as a ValueError.
        if (
            not isinstance(conditional_pixel_values_path, list)
            or len(conditional_pixel_values_path) == 0
            or not all(isinstance(path, str) for path in conditional_pixel_values_path)
        ):
            raise ValueError("conditional_pixel_values_path should be a list of strings")

        first_path = conditional_pixel_values_path[0]
        if not is_image_file(first_path) and not is_video_file(first_path):
            raise ValueError("conditional_pixel_values_path should be an image or video file path")
        if is_video_file(first_path) and len(conditional_pixel_values_path) > 1:
            raise ValueError("conditional_pixel_values_path should be a list of image file paths or a single video file path")

        if conditional_pixel_values_indices is not None and (
            not isinstance(conditional_pixel_values_indices, list)
            or not all(isinstance(index, int) for index in conditional_pixel_values_indices)
            or len(conditional_pixel_values_indices) != len(conditional_pixel_values_path)
        ):
            raise ValueError("conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path")

        if mask_type is not None and not mask_type in STR_TO_TYPE.keys() and not mask_type in STR_TO_TYPE.values():
            raise ValueError(f"Invalid mask type: {mask_type}")

        if not isinstance(max_hxw, int) or not (max_hxw >= 102400 and max_hxw <= 236544):
            raise ValueError("max_hxw should be an integer between 102400 and 236544")

        # Accept ints such as 0 or 1 as well as floats; bool is excluded on purpose.
        if (
            isinstance(noise_strength, bool)
            or not isinstance(noise_strength, (int, float))
            or not (noise_strength >= 0 and noise_strength <= 1)
        ):
            raise ValueError("noise_strength should be a float between 0 and 1")

        super().check_inputs(
            prompt,
            num_frames,
            height,
            width,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_embeds_2,
            negative_prompt_embeds_2,
            prompt_attention_mask_2,
            negative_prompt_attention_mask_2,
            callback_on_step_end_tensor_inputs,
        )

    def get_resize_transform(
        self,
        ori_height,
        ori_width,
        height=None,
        width=None,
        crop_for_hw=False,
        hw_stride=32,
        max_hxw=236544,  # forwarded to maxhwresize together with the original size
    ):
        """Build the resize/crop transform applied to the conditioning frames.

        With `crop_for_hw` the frames are center-cropped/resized to exactly
        (`height`, `width`). Otherwise the target size comes from
        `maxhwresize(ori_height, ori_width, max_hxw)` and is then cropped with
        `SpatialStrideCropVideo(stride=hw_stride)`.
        """
        if crop_for_hw:
            assert height is not None and width is not None
            transform = CenterCropResizeVideo((height, width))
        else:
            new_height, new_width = maxhwresize(ori_height, ori_width, max_hxw)
            transform = Compose(
                [
                    # We use CenterCropResizeVideo to share the same height and width, ensuring that
                    # the shape of the crop remains consistent when multiple images are captured
                    CenterCropResizeVideo((new_height, new_width)),
                    SpatialStrideCropVideo(stride=hw_stride),
                ]
            )
        return transform

    def get_video_transform(self):
        """Return the transform `ToTensorAfterResize()` followed by the map x -> 2 * x - 1."""
        norm_fun = Lambda(lambda x: 2. * x - 1.)
        transform = Compose([
            ToTensorAfterResize(),
            norm_fun
        ])
        return transform

    def get_mask_type_cond_indices(self, mask_type, conditional_pixel_values_path, conditional_pixel_values_indices, num_frames):
        """Resolve the mask type and the default conditioning frame indices.

        A string `mask_type` is mapped through `STR_TO_TYPE`. When `mask_type` is
        None a default is chosen from the input: one image -> i2v, two images ->
        transition, more images -> random_temporal, video -> continuation.
        For one / two images with `num_frames > 1`, missing indices default to
        `[0]` / `[0, -1]`.

        Returns:
            Tuple of (mask_type, conditional_pixel_values_indices).
        """
        if mask_type is not None and mask_type in STR_TO_TYPE.keys():
            mask_type = STR_TO_TYPE[mask_type]

        length_mismatch_msg = "conditional_pixel_values_indices should be a list of integers with the same length as conditional_pixel_values_path"

        if is_image_file(conditional_pixel_values_path[0]):
            if len(conditional_pixel_values_path) == 1:
                mask_type = MaskType.i2v if mask_type is None else mask_type
                if num_frames > 1:
                    conditional_pixel_values_indices = [0] if conditional_pixel_values_indices is None else conditional_pixel_values_indices
                    assert len(conditional_pixel_values_indices) == 1, length_mismatch_msg
            elif len(conditional_pixel_values_path) == 2:
                mask_type = MaskType.transition if mask_type is None else mask_type
                if num_frames > 1:
                    conditional_pixel_values_indices = [0, -1] if conditional_pixel_values_indices is None else conditional_pixel_values_indices
                    assert len(conditional_pixel_values_indices) == 2, length_mismatch_msg
            else:
                if num_frames > 1:
                    assert conditional_pixel_values_indices is not None and len(conditional_pixel_values_path) == len(conditional_pixel_values_indices), length_mismatch_msg
                mask_type = MaskType.random_temporal if mask_type is None else mask_type
        elif is_video_file(conditional_pixel_values_path[0]):
            # When the input is a video, video continuation is executed by default, with a continuation rate of double
            mask_type = MaskType.continuation if mask_type is None else mask_type

        return mask_type, conditional_pixel_values_indices

    def get_masked_pixel_values_mask(
        self,
        conditional_pixel_values,
        conditional_pixel_values_indices,
        mask_type,
        batch_size,
        num_samples_per_prompt,
        num_frames,
        height,
        width,
        video_transform,
        weight_dtype,
        device
    ):
        """Build the VAE-encoded masked conditioning video and its compressed mask.

        If the conditioning tensor already has `num_frames` frames, the mask comes
        from `self.mask_processor`. Otherwise the given frames are placed at
        `conditional_pixel_values_indices` of an all-zero video and the mask is 0
        at those frames and 1 elsewhere. Both tensors are repeated for
        `batch_size * num_samples_per_prompt`, and duplicated once more when
        classifier-free guidance is active.

        Returns:
            Tuple of (masked_pixel_values, mask), both cast to `weight_dtype`.
        """
        if device is None:
            device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')

        conditional_pixel_values = conditional_pixel_values.to(device=device, dtype=weight_dtype)

        if conditional_pixel_values.shape[0] == num_frames:
            inpaint_cond_data = self.mask_processor(conditional_pixel_values, mask_type=mask_type)
            masked_pixel_values, mask = inpaint_cond_data['masked_pixel_values'], inpaint_cond_data['mask']
        else:
            input_pixel_values = torch.zeros([num_frames, 3, height, width], device=device, dtype=weight_dtype)
            input_mask = torch.ones([num_frames, 1, height, width], device=device, dtype=weight_dtype)
            input_pixel_values[conditional_pixel_values_indices] = conditional_pixel_values
            input_mask[conditional_pixel_values_indices] = 0
            # Keep pixels only where the mask marks the frame as given (mask == 0).
            masked_pixel_values = input_pixel_values * (input_mask < 0.5)
            mask = input_mask

        print('conditional_pixel_values_indices', conditional_pixel_values_indices)
        print('mask_type', TYPE_TO_STR[mask_type])

        masked_pixel_values = video_transform(masked_pixel_values)

        num_repeats = batch_size * num_samples_per_prompt
        masked_pixel_values = masked_pixel_values.unsqueeze(0).repeat(num_repeats, 1, 1, 1, 1).transpose(1, 2).contiguous()  # b c t h w
        mask = mask.unsqueeze(0).repeat(num_repeats, 1, 1, 1, 1).transpose(1, 2).contiguous()  # b c t h w

        if self.noise_adder is not None:
            # add some noise to improve motion strength
            masked_pixel_values = self.noise_adder(masked_pixel_values, mask)

        masked_pixel_values = masked_pixel_values.to(self.vae.vae.dtype)
        masked_pixel_values = self.vae.encode(masked_pixel_values)

        mask = self.mask_compressor(mask)

        if self.do_classifier_free_guidance:
            masked_pixel_values = torch.cat([masked_pixel_values] * 2)
            mask = torch.cat([mask] * 2)

        masked_pixel_values = masked_pixel_values.to(weight_dtype)
        mask = mask.to(weight_dtype)

        return masked_pixel_values, mask

    @torch.no_grad()
    def __call__(
        self,
        conditional_pixel_values_path: Union[str, List[str]] = None,
        conditional_pixel_values_indices: Union[int, List[int]] = None,
        mask_type: Union[str, MaskType] = None,
        crop_for_hw: bool = False,
        max_hxw: int = 236544,
        noise_strength: Optional[float] = 0.0,
        prompt: Union[str, List[str]] = None,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_samples_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        prompt_attention_mask_2: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        guidance_rescale: float = 0.0,
        max_sequence_length: int = 512,
        device = None,
    ):
        """Run conditional (inpainting) generation.

        Args:
            conditional_pixel_values_path: list of image paths or a single-element
                list with a video path (validated by `check_inputs`).
            conditional_pixel_values_indices: frame positions for the given images.
            mask_type: mask type name or `MaskType`; a default is derived from the
                inputs when None.
            crop_for_hw: crop/resize the conditioning frames to (`height`, `width`)
                instead of deriving the size from `max_hxw`.
            max_hxw: bound forwarded to `maxhwresize`.
            noise_strength: when non-zero, a `GaussianNoiseAdder` with
                `mean=log(noise_strength)` is applied to the masked frames; when
                zero, no noise is added.
            output_type: `"latent"` returns the raw latents, anything else returns
                decoded videos cropped to the requested frames and real size.
            return_dict: return `OpenSoraPipelineOutput` if True, else a 1-tuple.

            The remaining arguments are forwarded to the prompt encoder, the
            scheduler, the transformer and the step callback as in the base
            pipeline.

        Returns:
            `OpenSoraPipelineOutput` or `(videos,)`.
        """
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. default height and width
        num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1
        height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
        width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            conditional_pixel_values_path,
            conditional_pixel_values_indices,
            mask_type,
            max_hxw,
            noise_strength,
            prompt,
            num_frames,
            height,
            width,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_embeds_2,
            negative_prompt_embeds_2,
            prompt_attention_mask_2,
            negative_prompt_attention_mask_2,
            callback_on_step_end_tensor_inputs,
        )
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = device or getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            num_samples_per_prompt=num_samples_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            text_encoder_index=0,
        )
        if self.tokenizer_2 is not None:
            (
                prompt_embeds_2,
                negative_prompt_embeds_2,
                prompt_attention_mask_2,
                negative_prompt_attention_mask_2,
            ) = self.encode_prompt(
                prompt=prompt,
                device=device,
                dtype=self.transformer.dtype,
                num_samples_per_prompt=num_samples_per_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds_2,
                negative_prompt_embeds=negative_prompt_embeds_2,
                prompt_attention_mask=prompt_attention_mask_2,
                negative_prompt_attention_mask=negative_prompt_attention_mask_2,
                max_sequence_length=77,
                text_encoder_index=1,
            )
        else:
            prompt_embeds_2 = None
            negative_prompt_embeds_2 = None
            prompt_attention_mask_2 = None
            negative_prompt_attention_mask_2 = None

        # ==================prepare inpaint data=====================================
        # Always (re)assign the noise adder: otherwise an adder created by an earlier
        # call with noise_strength != 0 would still be applied when noise_strength == 0.
        if noise_strength != 0:
            self.noise_adder = GaussianNoiseAdder(mean=np.log(noise_strength), std=0.01, clear_ratio=0)
        else:
            self.noise_adder = None

        mask_type, conditional_pixel_values_indices = self.get_mask_type_cond_indices(mask_type, conditional_pixel_values_path, conditional_pixel_values_indices, num_frames)

        conditional_pixel_values = get_pixel_values(conditional_pixel_values_path, num_frames)

        # Use the smallest height/width among the inputs as the reference size.
        min_height = min([pixels.shape[2] for pixels in conditional_pixel_values])
        min_width = min([pixels.shape[3] for pixels in conditional_pixel_values])

        resize_transform = self.get_resize_transform(
            ori_height=min_height,
            ori_width=min_width,
            height=height,
            width=width,
            crop_for_hw=crop_for_hw,
            max_hxw=max_hxw,
        )

        video_transform = self.get_video_transform()
        conditional_pixel_values = torch.cat([resize_transform(pixels) for pixels in conditional_pixel_values])
        real_height, real_width = conditional_pixel_values.shape[-2], conditional_pixel_values.shape[-1]
        # ==================prepare inpaint data=====================================

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        if get_sequence_parallel_state():
            world_size = hccl_info.world_size if torch_npu is not None else nccl_info.world_size
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_samples_per_prompt,
            num_channels_latents,
            # Under sequence parallelism each rank handles ceil(num_frames / world_size) frames.
            (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames,
            real_height,
            real_width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # ==============================create mask=====================================
        masked_pixel_values, mask = self.get_masked_pixel_values_mask(
            conditional_pixel_values,
            conditional_pixel_values_indices,
            mask_type,
            batch_size,
            num_samples_per_prompt,
            num_frames,
            real_height,
            real_width,
            video_transform,
            prompt_embeds.dtype,
            device
        )
        # ==============================create mask=====================================

        # 7 create image_rotary_emb, style embedding & time ids
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
            if self.tokenizer_2 is not None:
                prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
                prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])

        prompt_embeds = prompt_embeds.to(device=device)
        prompt_attention_mask = prompt_attention_mask.to(device=device)
        if self.tokenizer_2 is not None:
            prompt_embeds_2 = prompt_embeds_2.to(device=device)
            prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)

        # ==================make sp=====================================
        if get_sequence_parallel_state():
            prompt_embeds = rearrange(
                prompt_embeds,
                'b (n x) h -> b n x h',
                n=world_size,
                x=prompt_embeds.shape[1] // world_size
            ).contiguous()
            rank = hccl_info.rank if torch_npu is not None else nccl_info.rank
            prompt_embeds = prompt_embeds[:, rank, :, :]

            # Each rank keeps only its own temporal slice of the conditioning tensors.
            latents_num_frames = latents.shape[2]
            masked_pixel_values = masked_pixel_values[:, :, latents_num_frames * rank: latents_num_frames * (rank + 1)]
            mask = mask[:, :, latents_num_frames * rank: latents_num_frames * (rank + 1)]
        # ==================make sp=====================================

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # inpaint
                latent_model_input = torch.cat([latent_model_input, masked_pixel_values, mask], dim=1)

                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
                    dtype=latent_model_input.dtype
                )

                # ==================prepare my shape=====================================
                # predict the noise residual
                if prompt_embeds.ndim == 3:
                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
                if prompt_attention_mask.ndim == 2:
                    prompt_attention_mask = prompt_attention_mask.unsqueeze(1)  # b l -> b 1 l
                if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2:
                    # Fixed: the second embedding must be expanded here, not `prompt_embeds`
                    # (which would gain one extra dimension on every step).
                    prompt_embeds_2 = prompt_embeds_2.unsqueeze(1)  # b d -> b 1 d

                attention_mask = torch.ones_like(latent_model_input)[:, 0].to(device=device)
                # ==================prepare my shape=====================================

                # ==================make sp=====================================
                if get_sequence_parallel_state():
                    attention_mask = attention_mask.repeat(1, world_size, 1, 1)
                # ==================make sp=====================================

                noise_pred = self.transformer(
                    latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=t_expand,
                    pooled_projections=prompt_embeds_2,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Keep this as a plain loop: locals() must refer to this function's scope.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
                    negative_prompt_embeds_2 = callback_outputs.pop(
                        "negative_prompt_embeds_2", negative_prompt_embeds_2
                    )

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # ==================make sp=====================================
        if get_sequence_parallel_state():
            latents_shape = list(latents.shape)  # b c t//sp h w
            full_shape = [latents_shape[0] * world_size] + latents_shape[1:]  # b*sp c t//sp h w
            all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device)
            torch.distributed.all_gather_into_tensor(all_latents, latents)
            # Re-assemble the per-rank temporal slices along the frame dimension.
            latents_list = list(all_latents.chunk(world_size, dim=0))
            latents = torch.cat(latents_list, dim=2)
        # ==================make sp=====================================

        if not output_type == "latent":
            videos = self.decode_latents(latents)
            videos = videos[:, :num_frames, :real_height, :real_width]
        else:
            videos = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (videos, )

        return OpenSoraPipelineOutput(videos=videos)
@dataclass
class OpenSoraPipelineOutput(BaseOutput):
    """Return container used by the Open-Sora pipelines."""
    # Result handed back by a pipeline call via `OpenSoraPipelineOutput(videos=...)`.
    # NOTE(review): presumably decoded videos, or raw latents when the caller asks
    # for latent output -- confirm against the pipeline's `__call__`.
    videos: Union[List[torch.FloatTensor], np.ndarray]
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device passed on to `scheduler.set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timestep schedule and the number of inference steps
        (the schedule length when a custom schedule was supplied).

    Raises:
        ValueError: if both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps`
        does not accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive the schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: pick which keyword is being overridden and how it is named in the error text.
    if timesteps is not None:
        override_name, override_value, label = "timesteps", timesteps, "timestep"
    else:
        override_name, override_value, label = "sigmas", sigmas, "sigmas"

    accepted_parameters = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if override_name not in accepted_parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{override_name: override_value}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
def encode_prompt(
    self,
    prompt: str,
    device: torch.device = None,
    dtype: torch.dtype = None,
    num_samples_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: Optional[int] = None,
    text_encoder_index: int = 0,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        dtype (`torch.dtype`):
            torch dtype
        num_samples_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
        max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
            Defaults to 512 for text encoder `0` and 77 for text encoder `1`.
        text_encoder_index (`int`, *optional*):
            Index of the text encoder to use. `0` for T5 and `1` for clip.

    Returns:
        Tuple of `(prompt_embeds, negative_prompt_embeds, prompt_attention_mask,
        negative_prompt_attention_mask)`.
    """
    if dtype is None:
        if self.text_encoder_2 is not None:
            dtype = self.text_encoder_2.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype

    if device is None:
        device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda')

    tokenizers = [self.tokenizer, self.tokenizer_2]
    text_encoders = [self.text_encoder, self.text_encoder_2]

    tokenizer = tokenizers[text_encoder_index]
    text_encoder = text_encoders[text_encoder_index]

    if max_sequence_length is None:
        if text_encoder_index == 0:
            max_length = 512
        if text_encoder_index == 1:
            max_length = 77
    else:
        max_length = max_sequence_length

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            # The prompt was truncated to `max_length` above, so report the dropped tokens
            # relative to `max_length`. `tokenizer.model_max_length` may differ from it,
            # and this path is shared by both text encoders, not only CLIP.
            removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because the text encoder can only handle"
                f" sequences up to {max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            attention_mask=prompt_attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

        if text_encoder_index == 1:
            prompt_embeds = prompt_embeds.unsqueeze(1)  # b d -> b 1 d for clip

        prompt_attention_mask = prompt_attention_mask.repeat(num_samples_per_prompt, 1)

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_samples_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_samples_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=negative_prompt_attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

        if text_encoder_index == 1:
            negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1)  # b d -> b 1 d for clip

        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_samples_per_prompt, 1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_samples_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_samples_per_prompt, seq_len, -1)

    return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
def check_inputs(
    self,
    prompt,
    num_frames,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
    prompt_embeds_2=None,
    negative_prompt_embeds_2=None,
    prompt_attention_mask_2=None,
    negative_prompt_attention_mask_2=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the generation arguments and raise ``ValueError`` on the first problem found.

    Checks performed, in order:
      * ``num_frames - 1`` is a multiple of 4; ``height`` and ``width`` are multiples of 8;
      * every name in ``callback_on_step_end_tensor_inputs`` is listed in
        ``self._callback_tensor_inputs``;
      * exactly one of ``prompt`` / ``prompt_embeds`` is supplied, and ``prompt`` is a ``str``
        or a ``list``;
      * when no ``prompt`` is supplied and the pipeline has a second tokenizer,
        ``prompt_embeds_2`` is supplied as well;
      * every pre-computed embedding is accompanied by its attention mask;
      * ``negative_prompt`` and ``negative_prompt_embeds`` are not both supplied;
      * positive and negative embeddings supplied together have identical shapes.

    Returns:
        None. The method only raises.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if (num_frames - 1) % 4 != 0:
        # Report the value under its real name: the message used to present `num_frames`
        # as if it were `num_frames - 1`.
        raise ValueError(f"`num_frames - 1` has to be divisible by 4 but `num_frames` is {num_frames}.")
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # The second text encoder is optional (the pipeline call skips it when
    # `self.tokenizer_2` is None), so its embeddings are only mandatory when it exists.
    has_second_encoder = getattr(self, "tokenizer_2", None) is not None

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is None and prompt_embeds_2 is None and has_second_encoder:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

    if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
        raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

    if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
        raise ValueError(
            "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
    if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
        if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
            raise ValueError(
                "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
                f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
                f" {negative_prompt_embeds_2.shape}."
            )
None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds_2: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, prompt_attention_mask_2: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], guidance_rescale: float = 0.0, max_sequence_length: int = 512, device = None, ): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. default height and width num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1 height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, num_frames, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, prompt_embeds_2, negative_prompt_embeds_2, prompt_attention_mask_2, negative_prompt_attention_mask_2, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = device or getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('cuda') # 3. 
Encode input prompt ( prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, device=device, dtype=self.transformer.dtype, num_samples_per_prompt=num_samples_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, text_encoder_index=0, ) if self.tokenizer_2 is not None: ( prompt_embeds_2, negative_prompt_embeds_2, prompt_attention_mask_2, negative_prompt_attention_mask_2, ) = self.encode_prompt( prompt=prompt, device=device, dtype=self.transformer.dtype, num_samples_per_prompt=num_samples_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds_2, negative_prompt_embeds=negative_prompt_embeds_2, prompt_attention_mask=prompt_attention_mask_2, negative_prompt_attention_mask=negative_prompt_attention_mask_2, max_sequence_length=77, text_encoder_index=1, ) else: prompt_embeds_2 = None negative_prompt_embeds_2 = None prompt_attention_mask_2 = None negative_prompt_attention_mask_2 = None # 4. Prepare timesteps if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) else: timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # 5. 
Prepare latent variables if get_sequence_parallel_state(): world_size = hccl_info.world_size if torch_npu is not None else nccl_info.world_size num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_samples_per_prompt, num_channels_latents, (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) else: extra_step_kwargs = {} # 7 create image_rotary_emb, style embedding & time ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) if self.tokenizer_2 is not None: prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) prompt_embeds = prompt_embeds.to(device=device) prompt_attention_mask = prompt_attention_mask.to(device=device) if self.tokenizer_2 is not None: prompt_embeds_2 = prompt_embeds_2.to(device=device) prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) # ==================make sp===================================== if get_sequence_parallel_state(): prompt_embeds = rearrange( prompt_embeds, 'b (n x) h -> b n x h', n=world_size, x=prompt_embeds.shape[1] // world_size ).contiguous() rank = hccl_info.rank if torch_npu is not None else nccl_info.rank prompt_embeds = prompt_embeds[:, rank, :, :] # ==================make sp===================================== # 8. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): timestep = torch.tensor([t] * latent_model_input.shape[0], device=device).to( dtype=latent_model_input.dtype ) else: timestep = t.expand(latent_model_input.shape[0]) # ==================prepare my shape===================================== # predict the noise residual if prompt_embeds.ndim == 3: prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d if prompt_attention_mask.ndim == 2: prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2: prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d attention_mask = torch.ones_like(latent_model_input)[:, 0].to(device=device) # ==================prepare my shape===================================== # ==================make sp===================================== if get_sequence_parallel_state(): attention_mask = attention_mask.repeat(1, world_size, 1, 1) # ==================make sp===================================== noise_pred = self.transformer( latent_model_input, attention_mask=attention_mask, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=timestep, pooled_projections=prompt_embeds_2, return_dict=False, )[0] assert not torch.any(torch.isnan(noise_pred)) # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + 
guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and guidance_rescale > 0.0 and not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) negative_prompt_embeds_2 = callback_outputs.pop( "negative_prompt_embeds_2", negative_prompt_embeds_2 ) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() # ==================make sp===================================== if get_sequence_parallel_state(): latents_shape = list(latents.shape) # b c t//sp h w full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # # b*sp c t//sp h w all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device) torch.distributed.all_gather_into_tensor(all_latents, latents) latents_list = list(all_latents.chunk(world_size, dim=0)) latents = torch.cat(latents_list, dim=2) # ==================make sp===================================== if not output_type == "latent": videos = self.decode_latents(latents) videos = videos[:, :num_frames, :height, :width] else: videos = latents # Offload all models self.maybe_free_model_hooks() if not return_dict: return (videos, ) return 
def decode_latents(self, latents):
    """Decode ``latents`` with ``self.vae`` and return the result as ``uint8`` frames on the CPU.

    The max, min, mean and std of the input latents and of the decoded tensor are printed
    as a debugging aid.

    Args:
        latents: tensor to decode; it is cast to ``self.vae.vae.dtype`` before decoding.

    Returns:
        A contiguous ``torch.uint8`` CPU tensor. Decoder values are mapped from ``[-1, 1]`` to
        ``[0, 255]`` (values outside that range are clamped) and axis 2 of the decoder output
        is moved to the last position (``b t h w c``).
    """

    def _print_stats(tag, tensor):
        # Same line format for both debug prints: tag + shape, then max / min / mean / std.
        print(
            f'{tag} vae decode {tensor.shape}',
            torch.max(tensor).item(),
            torch.min(tensor).item(),
            torch.mean(tensor).item(),
            torch.std(tensor).item(),
        )

    _print_stats('before', latents)
    video = self.vae.decode(latents.to(self.vae.vae.dtype))
    # Report the decoded tensor's own shape; the old message printed the latent shape
    # next to the decoded statistics.
    _print_stats('after', video)
    video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()  # b t h w c
    return video
def main(args: argparse.Namespace):
    """Run one image through the selected autoencoder and save the reconstruction.

    Reads ``args.image_path``, encodes and decodes it in float16 with the model built from
    ``ae_wrapper[args.ae](args.ae_path)``, and writes the result to ``args.rec_path``.
    Tiling is enabled on the inner ``vae`` when ``args.enable_tiling`` is set, using
    ``args.tile_overlap_factor``.

    Args:
        args: parsed command-line namespace; the fields read are ``image_path``,
            ``short_size``, ``device``, ``ae``, ``ae_path``, ``enable_tiling``,
            ``tile_overlap_factor`` and ``rec_path``.
    """
    image_path = args.image_path
    short_size = args.short_size
    device = args.device
    kwarg = {}

    # Build the model once in eval mode on the target device; the repeated
    # `eval()` / `to(device)` calls of the original were no-ops and are dropped.
    vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor
    vae = vae.half()

    with torch.no_grad():
        # Use a context manager so the image file handle is closed promptly, and force
        # three channels so palette / grayscale / RGBA files do not reach the model with
        # an unexpected channel count.
        # NOTE(review): assumes the autoencoder expects 3-channel RGB input — confirm.
        with Image.open(image_path) as img:
            x_vae = preprocess(img.convert("RGB"), short_size)
        x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w
        latents = vae.encode(x_vae)
        latents = latents.to(torch.float16)
        image_recon = vae.decode(latents)  # b t c h w

    # First frame of the first sample, mapped from [-1, 1] to uint8 [0, 255].
    x = image_recon[0, 0, :, :, :]
    x = x.squeeze()
    x = x.detach().cpu().numpy()
    x = np.clip(x, -1, 1)
    x = (x + 1) / 2
    x = (255 * x).astype(np.uint8)
    x = x.transpose(1, 2, 0)  # c h w -> h w c, the layout Image.fromarray expects
    image = Image.fromarray(x)
    image.save(args.rec_path)
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    """Load ``num_frames`` evenly spaced frames from the start of a video.

    Frames are taken from the window ``[0, sample_rate * num_frames)``. When the clip is
    shorter than that window, the window is clamped to the clip length, so the same number
    of frames is still returned (at a smaller stride, possibly with repeated indices)
    instead of requesting frame indices that do not exist.

    Args:
        video_path: path passed to ``decord.VideoReader``.
        num_frames: number of frames to return.
        sample_rate: nominal stride between sampled frames.

    Returns:
        Tensor laid out as ``(C, T, H, W)``, obtained by permuting the decoded
        ``(T, H, W, C)`` batch.

    Raises:
        ValueError: if the video contains no frames.
    """
    decord_vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(decord_vr)
    if total_frames == 0:
        raise ValueError(f'video {video_path} contains no frames')

    sample_frames_len = sample_rate * num_frames
    s = 0
    # Clamp the window end to the clip length; the unclamped value produced
    # out-of-range indices for clips shorter than `sample_rate * num_frames`.
    e = min(sample_frames_len, total_frames)
    print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
          total_frames)

    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data
def main(args: argparse.Namespace):
    """Reconstruct a video clip through the selected autoencoder and write it to disk.

    The clip at ``args.video_path`` is read with ``read_video``, resized/cropped by
    ``preprocess``, encoded and decoded in float32 by ``ae_wrapper[args.ae](args.ae_path)``,
    and the first reconstructed sample is saved to ``args.rec_path`` at ``args.fps``.
    The shapes of the input, the latents and the reconstruction are printed along the way.

    Args:
        args: parsed command-line namespace; the fields read are ``device``, ``ae``,
            ``ae_path``, ``enable_tiling``, ``tile_overlap_factor``, ``video_path``,
            ``num_frames``, ``sample_rate``, ``height``, ``width``, ``fps`` and ``rec_path``.
    """
    target_device = args.device
    compute_dtype = torch.float32
    extra_kwargs = {}

    model = ae_wrapper[args.ae](args.ae_path, **extra_kwargs).eval().to(target_device)
    if args.enable_tiling:
        # Tiling is configured on the wrapped inner model.
        model.vae.enable_tiling()
        model.vae.tile_overlap_factor = args.tile_overlap_factor

    model.eval()
    model = model.to(target_device, dtype=compute_dtype)

    with torch.no_grad():
        clip = read_video(args.video_path, args.num_frames, args.sample_rate)
        x_vae = preprocess(clip, args.height, args.width)
        print(x_vae.shape)
        x_vae = x_vae.to(target_device, dtype=compute_dtype)  # b c t h w

        latents = model.encode(x_vae)
        print(latents.shape)

        latents = latents.to(compute_dtype)
        video_recon = model.decode(latents)  # b t c h w
        print(video_recon.shape)

    # Only the first sample of the batch is written out.
    custom_to_video(video_recon[0], fps=args.fps, output_file=args.rec_path)
# NPU support is optional: when it cannot be loaded, both names fall back to ``None`` so the
# rest of the script can branch on ``torch_npu is not None``.
try:
    import torch_npu
    from opensora.npu_config import npu_config
except Exception:
    # Catch `Exception` rather than using a bare `except:` so that any import-time failure
    # still triggers the fallback, while KeyboardInterrupt / SystemExit propagate.
    torch_npu = None
    npu_config = None
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a generation request.

    Args:
        seed: the seed supplied by the caller.
        randomize_seed: when truthy, ``seed`` is ignored and a new one is drawn.

    Returns:
        ``seed`` unchanged, or a value from ``random.randint(0, MAX_SEED)`` when
        ``randomize_seed`` is truthy.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
Open-Sora Plan logo
""" TITLE = """
Open-Sora Plan🤗
""" DESCRIPTION = """
Support Chinese and English; 支持中英双语
Welcome to Star🌟 our GitHub
""" t2v_prompt_examples = [ "动画场景特写中,一个矮小、毛茸茸的怪物跪在一根融化的红蜡烛旁。三维写实的艺术风格注重光照和纹理的相互作用,在整个场景中投射出引人入胜的阴影。怪物睁着好奇的大眼睛注视着火焰,它的皮毛在温暖闪烁的光芒中轻轻拂动。镜头慢慢拉近,捕捉到怪物皮毛的复杂细节和精致的熔蜡液滴。怪物试探性地伸出一只爪子,似乎想要触碰火焰,而烛光则在它周围闪烁舞动,气氛充满了惊奇和好奇。", "An animated scene features a close-up of a short, fluffy monster kneeling beside a melting red candle. The 3D, realistic art style focuses on the interplay of lighting and texture, casting intriguing shadows across the scene. The monster gazes at the flame with wide, curious eyes, its fur gently ruffling in the warm, flickering glow. The camera slowly zooms in, capturing the intricate details of the monster's fur and the delicate, molten wax droplets. The atmosphere is filled with a sense of wonder and curiosity, as the monster tentatively reaches out a paw, as if to touch the flame, while the candlelight dances and flickers around it.", "特写镜头捕捉到一只维多利亚皇冠鸽,其醒目的蓝色羽毛和鲜艳的红色胸部格外显眼。这只鸽子精致的花边鸽冠和醒目的红眼更增添了它的威严。鸽子的头部略微偏向一侧,给人一种威严的感觉。背景被模糊处理,使人们的注意力集中在鸽子引人注目的特征上。柔和的光线洒在画面上,投下柔和的阴影,增强了鸽子羽毛的质感。鸽子微微扇动翅膀,嘴角向上翘起,似乎在好奇地观察周围的环境,营造出一种动感迷人的氛围。", "A close-up shot captures a Victoria crowned pigeon, its striking blue plumage and vibrant red chest standing out prominently. The bird's delicate, lacy crest and striking red eye add to its regal appearance. The pigeon's head is tilted slightly to the side, giving it a majestic look. The background is blurred, drawing attention to the bird's striking features. Soft light bathes the scene, casting gentle shadows that enhance the texture of its feathers. The pigeon flutters its wings slightly, and its beak tilts upwards, as if curiously observing the surroundings, creating a dynamic and captivating atmosphere.", "一架无人机捕捉到了大苏尔加雷角海滩上海浪拍打着崎岖悬崖的壮丽景色。湛蓝的海水拍打出白色的浪花,夕阳的金光照亮了岩石海岸,投下长长的阴影,营造出温暖宁静的氛围。远处矗立着一座小岛,岛上有一座灯塔,更增添了画面的魅力。海鸥在头顶上滑翔,海风吹过附近的植被,沙沙作响,给宁静的海岸景观带来了勃勃生机。", "A drone captures a breathtaking view of waves crashing against the rugged cliffs along Big Sur's Garay Point Beach. 
The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore, casting long shadows and creating a warm, serene atmosphere. A small island with a lighthouse stands in the distance, adding to the scene's charm. Seagulls glide overhead as the ocean breeze rustles through the nearby vegetation, bringing life to the tranquil coastal landscape.", "一个二十岁出头的年轻人,头发蓬松,鼻梁上架着一副眼镜,安详地坐在高高飘扬的蓬松白云上。他全神贯注地读着一本书,偶尔抬起头看一眼周围翱翔的鸟儿。阳光透过飘渺的云层,在这幅画面上洒下柔和的金色光芒,并在他的脸上投下俏皮的影子。当他翻开书页时,一阵微风吹过,书页沙沙作响,他微笑着,感受着失重和自由的快感。", "A young man in his early twenties, with tousled hair and a pair of glasses perched on the end of his nose, sits serenely on a fluffy, white cloud floating high in the sky. He is engrossed in a book, occasionally glancing up to watch the birds soar around him. The sunlight filters through the wispy clouds, casting a soft, golden glow over the scene and creating playful shadows that dance on his face. As he turns a page, a gentle breeze rustles the pages, and he smiles, feeling the thrill of weightlessness and freedom.", "三维动画描绘了一只圆滚滚、毛茸茸的小动物,它有一双富于表情的大眼睛,正在探索一片生机勃勃的魔法森林。这个异想天开的生物是兔子和松鼠的混合体,长着柔软的蓝色皮毛和浓密的条纹尾巴。它沿着波光粼粼的溪流蹦蹦跳跳,眼睛睁得大大的,充满了好奇。森林里充满了神奇的元素:会发光和变色的花朵、长着紫色和银色树叶的树木,还有像萤火虫一样的小浮光。它跳着跳着,停了下来,与一群围着蘑菇圈跳舞的小精灵嬉戏互动。然后,它抬头敬畏地看着一棵发光的大树,这棵树似乎是森林的核心。摄像机平稳地摇镜头,捕捉到这只小动物好奇地伸手触摸一朵发光的花朵,花朵随之变色。整个场景沐浴在柔和、空灵的光线中,背景中的阴影轻轻舞动,营造出一种令人陶醉和惊奇的氛围。小动物的嬉戏打闹和神奇的氛围让森林变得生机勃勃,仿佛每一刻都是一次发现和喜悦。", "A 3D animation depicts a small, round, fluffy creature with big, expressive eyes exploring a vibrant, enchanted forest. This whimsical creature, a blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. 
As the creature hops, it pauses to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. It then looks up in awe at a large, glowing tree that seems to be the heart of the forest. The camera pans smoothly to capture the creature's curiosity as it reaches out to touch a glowing flower, causing it to change colors. The scene is bathed in a soft, ethereal light, with shadows dancing gently in the background, creating an atmosphere of enchantment and wonder. The creature's playful antics and the magical ambiance make the forest come alive, as if every moment is a discovery and a delight.", "一架无人机优雅地环绕着阿马尔菲海岸崎岖不平的山顶上一座历史悠久的教堂,拍摄其宏伟的建筑细节以及层层叠叠的小径和天井。下方,海浪拍打着岩石,地平线延伸至意大利的沿海水域和丘陵地貌。远处的身影在天井中漫步,欣赏着壮丽的海景,营造出一幅动感十足的画面。午后和煦的阳光让整个场景沐浴在神奇而浪漫的光影中,投下长长的阴影,为迷人的景色增添了深度。镜头不时拉近以突出教堂错综复杂的细节,然后拉远以展示广阔的海岸线,营造出引人入胜的视觉叙事效果。", "A drone camera gracefully circles a historic church perched on a rugged outcropping along the Amalfi Coast, capturing its magnificent architectural details and tiered pathways and patios. Below, waves crash against the rocks, while the horizon stretches out over the coastal waters and hilly landscapes of Italy. Distant figures stroll and enjoy the breathtaking ocean views from the patios, creating a dynamic scene. The warm glow of the afternoon sun bathes the scene in a magical and romantic light, casting long shadows and adding depth to the stunning vista. 
The camera occasionally zooms in to highlight the intricate details of the church, then pans out to showcase the expansive coastline, creating a captivating visual narrative.", "一个特写镜头捕捉到一位 60 多岁、留着胡子的白发老人,他坐在巴黎的一家咖啡馆里陷入沉思,思考着宇宙的历史。他的眼睛紧紧盯着屏幕外走动的人们,而自己却一动不动。他身着羊毛大衣、纽扣衬衫、棕色贝雷帽,戴着一副眼镜,散发着教授的风范。他偶尔瞥一眼四周,目光停留在背景中熙熙攘攘的巴黎街道和城市景观上。场景沐浴在金色的光线中,让人联想到 35 毫米电影胶片。当他微微前倾时,眼睛睁大,露出顿悟的瞬间,并微微闭口微笑,暗示他已经找到了生命奥秘的答案。景深营造出光影交错的动态效果,烘托出智慧沉思的氛围。", "An extreme close-up captures a gray-haired man with a beard in his 60s, deep in thought as he sits at a Parisian cafe, contemplating the history of the universe. His eyes focus intently on people walking offscreen, while he remains mostly motionless. Dressed in a wool coat, a button-down shirt, a brown beret, and glasses, he exudes a professorial demeanor. The man occasionally glances around, his gaze lingering on the bustling Parisian streets and cityscape in the background. The scene is bathed in golden light, reminiscent of a cinematic 35mm film. As he leans forward slightly, his eyes widen in a moment of epiphany, and he offers a subtle, closed-mouth smile, suggesting he has found the answer to the mystery of life. The depth of field creates a dynamic interplay of light and shadow, enhancing the atmosphere of intellectual contemplation.", "一只欢快的水獭穿着明黄色的救生衣,自信地在冲浪板上保持平衡,在郁郁葱葱的热带岛屿附近波光粼粼的绿松石水域中滑行。该场景采用三维数字艺术风格渲染,阳光在水面上投下俏皮的阴影。水獭不时将爪子伸入水中,溅起的水珠捕捉到光线,为宁静的氛围增添了动感和刺激。", "A cheerful otter confidently balances on a surfboard, donning a bright yellow lifejacket, as it glides through the shimmering turquoise waters near lush tropical islands. The scene is rendered in a 3D digital art style, with the sunlight casting playful shadows on the water's surface. 
The otter occasionally dips its paws into the water, sending up sprays of droplets that catch the light, adding a sense of motion and excitement to the tranquil atmosphere.", "在这幅迷人的特写镜头中,一只变色龙展示了它非凡的变色能力,在柔和的散射光中,它鲜艳的色调微妙地变换着。模糊的背景凸显了变色龙醒目的外表,而光影的交错则突出了变色龙皮肤的复杂细节。", "In this captivating close-up shot, a chameleon displays its remarkable color-changing abilities, its vibrant hues shifting subtly in the soft, diffused light. The blurred background highlights the animal's striking appearance, while the interplay of light and shadow accentuates the intricate details of its skin.", "圣托里尼在蓝色时刻的壮丽鸟瞰图捕捉到了白色基克拉迪建筑与蓝色圆顶的迷人建筑,在黄昏的天空中投射出长长的阴影。火山口的景色令人惊叹,光与影的交织营造出宁静的氛围。当太阳落到地平线以下时,夕阳的余晖将整个场景笼罩在温暖的金色中,海鸥在空中优雅地翱翔,几艘帆船在下方的火山口悠闲地漂流。", "A breathtaking aerial view of Santorini during the blue hour captures the stunning architecture of white Cycladic buildings with blue domes, casting long shadows against the twilight sky. The caldera views are awe-inspiring, with the interplay of light and shadow creating a serene atmosphere. As the sun dips below the horizon, the gentle glow of the setting sun bathes the scene in a warm, golden hue, while seagulls soar gracefully through the air and a few sailboats drift lazily in the caldera below.", "一群羊驼在鲜艳的涂鸦墙前自信地摆着姿势,每只羊驼都穿着五颜六色的羊毛针织衫,戴着时尚的太阳镜。在正午明媚的阳光下,它们嬉戏互动,有的好奇地东张西望,有的则亲昵地偎依在一起。光与影的鲜明对比增强了这一场景的动感活力,营造出一种融合了都市前卫与奇异魅力的氛围。", "A group of alpacas, each donning colorful knit wool sweaters and stylish sunglasses, pose confidently against a vibrant graffiti-covered wall. Under the bright midday sun, they interact playfully with one another, some glancing around curiously while others nuzzle affectionately. 
The scene's dynamic energy is heightened by the stark interplay of light and shadow, creating an atmosphere that blends urban edginess with whimsical charm.", "一只充满活力的动画兔子,身穿俏皮的粉色滑雪服,在湛蓝的天空下,熟练地从积雪的山坡上滑下。兔子充满活力地跳跃和旋转,在闪闪发光的雪地上投下动态阴影,而阳光的明亮光线则凸显了闪闪发光的景观,营造出一种欢快的氛围。当兔子下降时,它的流畅动作被广角镜头捕捉到,增加了速度感和刺激感。", "A vibrant animated rabbit, dressed in a playful pink snowboarding outfit, expertly carves its way down a snowy mountain slope under a clear blue sky. The rabbit performs energetic jumps and spins, casting dynamic shadows on the glistening snow, while the sun's bright rays highlight the sparkling landscape, creating an atmosphere of joyful exhilaration. As the rabbit descends, its fluid motions are captured in a sweeping camera angle, adding to the sense of speed and excitement.", "食物镜头,完美的汉堡,配上奶酪和生菜,微距拍摄,旋转拍摄,推拉镜头", "food shot, a perfect burger in a bun with cheese and lettuce, macro shot, rotating shot, dolly in", "这幅肖像画描绘了一只长着蓝眼睛的橘色猫,缓缓旋转,灵感来自维米尔的《戴珍珠耳环的少女》。这只猫戴着珍珠耳环,棕色的皮毛像荷兰帽一样,背景为黑色,在工作室灯光的映衬下显得格外明亮。", "This portrait depicts an orange cat with blue eyes, slowly rotating, inspired by Vermeer ’s ’Girl with a Pearl Earring’. The cat is adorned with pearl earrings and has brown fur styled like a Dutch cap against a black background, illuminated by studio lighting.", "一只熊猫在竹林下弹奏吉他,它的爪子轻轻拨动琴弦,一群着迷的兔子观看着,音乐与竹叶的沙沙声融为一体。高清。", "A panda strumming a guitar under a bamboo grove, its paws gently plucking the strings as a group of mesmerized rabbits watch, the music blending with the rustle of bamboo leaves. HD.", "雪花玻璃球摇晃后,会呈现出一座微型城市,雪花实际上是闪闪发光的星星。建筑物亮起,反射着天上的雪花,微小的人影在街道上移动,他们的路径被柔和的星光照亮,营造出神奇、宁静的城市景观。高清。", "A snow globe, when shaken, reveals a miniature city where the snowflakes are actually glowing stars. The buildings light up, reflecting the celestial snowfall, and tiny figures move through the streets, their paths illuminated by the gentle starlight, creating a magical, peaceful urban landscape. 
HD.", "魔术师水晶球的特写,展现了水晶球内部的未来城市景观。摩天大楼的光影直冲云霄,飞行汽车在空中飞驰,在水晶球表面投射出霓虹灯的反光。8K。", "A close-up of a magician’s crystal ball that reveals a futuristic cityscape within. Skyscrapers of light stretch towards the heavens, and flying cars zip through the air, casting neon reflections across the ball’s surface. 8K.", ] style_list = [ { "name": "(Default)", "prompt": "(masterpiece), (best quality), (ultra-detailed), (unwatermarked), {prompt}", "negative_prompt": NEG_PROMPT, }, { "name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured. ", }, { "name": "Photographic", "prompt": "cinematic photo, a close-up of {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly. ", }, { "name": "Anime", "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast. ", }, { "name": "Manga", "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style", "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style. ", }, { "name": "Digital Art", "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", "negative_prompt": "photo, photorealistic, realism, ugly. ", }, { "name": "Pixel art", "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics", "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic. 
", }, { "name": "Fantasy art", "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white. ", }, { "name": "Neonpunk", "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured. ", }, { "name": "3D Model", "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting", "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting. 
", }, ] ================================================ FILE: opensora/serve/gradio_web_server.py ================================================ import gradio as gr import os import torch from einops import rearrange import torch.distributed as dist from torchvision.utils import save_image import imageio import math import argparse import random import numpy as np import string from opensora.sample.caption_refiner import OpenSoraCaptionRefiner from opensora.utils.sample_utils import ( prepare_pipeline, save_video_grid, init_gpu_env ) from .gradio_utils import * @torch.no_grad() @torch.inference_mode() def generate( prompt: str, seed: int = 0, num_frames: int = 29, num_samples: int = 1, guidance_scale: float = 4.5, num_inference_steps: int = 25, randomize_seed: bool = False, progress=gr.Progress(track_tqdm=False), ): seed = int(randomize_seed_fn(seed, randomize_seed)) if seed is not None: torch.manual_seed(seed) if not os.path.exists(args.save_img_path): os.makedirs(args.save_img_path, exist_ok=True) video_grids = [] text_prompt = [prompt] for index, prompt in enumerate(text_prompt): if caption_refiner_model is not None: refine_prompt = caption_refiner_model.get_refiner_output(prompt) print(f'\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}') prompt = refine_prompt input_prompt = POS_PROMPT.format(prompt) videos = pipeline( input_prompt, negative_prompt=NEG_PROMPT, num_frames=num_frames, height=352, width=640, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_samples_per_prompt=num_samples, max_sequence_length=512, device=device, ).videos if num_frames != 1 and enhance_video_model is not None: # b t h w c videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250) if num_frames == 1: videos = rearrange(videos, 'b t h w c -> (b t) c h w') if num_samples != 1: for i, image in enumerate(videos): save_image( image / 255.0, os.path.join( args.save_img_path, 
@torch.no_grad()
@torch.inference_mode()
def generate(
    prompt: str,
    seed: int = 0,
    num_frames: int = 29,
    num_samples: int = 1,
    guidance_scale: float = 4.5,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=False),
):
    """Sample from the text-to-video pipeline for one prompt and save the results.

    The prompt is optionally rewritten by ``caption_refiner_model``, wrapped in
    ``POS_PROMPT`` and sent to the module-level ``pipeline`` at 352x640.
    Per-sample files and a combined grid are written under
    ``args.save_img_path`` (``.jpg`` when ``num_frames == 1``, ``.mp4``
    otherwise).

    Returns:
        ``(final_path, seed)``: path of the combined grid file and the seed
        that was actually used.

    NOTE(review): block nesting was reconstructed from whitespace-collapsed
    source; confirm the grid-building statements belong to the
    ``num_samples != 1`` video branch.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    if seed is not None:
        torch.manual_seed(seed)

    if not os.path.exists(args.save_img_path):
        os.makedirs(args.save_img_path, exist_ok=True)

    def sample_path(idx, suffix):
        # Every per-prompt artefact shares this stem; only the suffix differs.
        return os.path.join(
            args.save_img_path,
            f'{args.sample_method}_{idx}_gs{guidance_scale}_s{num_inference_steps}{suffix}',
        )

    single_frame = num_frames == 1
    collected = []
    for index, text in enumerate([prompt]):
        if caption_refiner_model is not None:
            refined = caption_refiner_model.get_refiner_output(text)
            print(f'\nOrigin prompt: {text}\n->\nRefine prompt: {refined}')
            text = refined
        input_prompt = POS_PROMPT.format(text)

        videos = pipeline(
            input_prompt,
            negative_prompt=NEG_PROMPT,
            num_frames=num_frames,
            height=352,
            width=640,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_samples_per_prompt=num_samples,
            max_sequence_length=512,
            device=device,
        ).videos

        if not single_frame and enhance_video_model is not None:
            # b t h w c
            videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250)

        if single_frame:
            videos = rearrange(videos, 'b t h w c -> (b t) c h w')
            grid_cols = math.ceil(math.sqrt(videos.shape[0]))
            if num_samples != 1:
                for i, frame in enumerate(videos):
                    save_image(
                        frame / 255.0,
                        sample_path(index, f'_i{i}.jpg'),
                        nrow=grid_cols,
                        normalize=True,
                        value_range=(0, 1),
                    )  # b c h w
            save_image(
                videos / 255.0,
                sample_path(index, '.jpg'),
                nrow=grid_cols,
                normalize=True,
                value_range=(0, 1),
            )  # b c h w
        elif num_samples == 1:
            # highest quality is 10, lowest is 0
            imageio.mimwrite(sample_path(index, '.mp4'), videos[0], fps=args.fps, quality=6)
        else:
            for i in range(num_samples):
                imageio.mimwrite(sample_path(index, f'_i{i}.mp4'), videos[i], fps=args.fps, quality=6)
            videos = save_video_grid(videos)
            imageio.mimwrite(sample_path(index, '.mp4'), videos, fps=args.fps, quality=6)
            videos = videos.unsqueeze(0)  # 1 t h w c
        collected.append(videos)

    video_grids = torch.cat(collected, dim=0)
    final_stem = os.path.join(
        args.save_img_path,
        f'{args.sample_method}_gs{guidance_scale}_s{num_inference_steps}',
    )
    # Random 4-letter tag keeps successive runs with equal settings from overwriting each other.
    random_string = ''.join(random.choices(string.ascii_letters, k=4))
    if single_frame:
        final_path = final_stem + f'_{random_string}.jpg'
        save_image(
            video_grids / 255.0,
            final_path,
            nrow=math.ceil(math.sqrt(len(video_grids))),
            normalize=True,
            value_range=(0, 1),
        )
    else:
        video_grids = save_video_grid(video_grids)
        final_path = final_stem + f'_{random_string}.mp4'
        imageio.mimwrite(final_path, video_grids, fps=args.fps, quality=6)

    print('save path {}'.format(args.save_img_path))
    return final_path, seed
# ---------------------------------------------------------------------------
# Text-to-video Gradio demo: CLI parsing, model loading, UI layout and launch.
# Everything here runs at import time (the __main__ guard below is commented out).
# NOTE(review): `with` nesting was reconstructed from whitespace-collapsed
# source; confirm the layout hierarchy against the real file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')
parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5'])
parser.add_argument("--caption_refiner", type=str, default=None)
parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8')
parser.add_argument("--ae_path", type=str, default='CausalVAEModel_4x8x8')
parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl')
parser.add_argument("--text_encoder_name_2", type=str, default=None)
parser.add_argument("--save_img_path", type=str, default="./test_gradio")
parser.add_argument("--fps", type=int, default=18)
parser.add_argument('--enable_tiling', action='store_true')
parser.add_argument('--save_memory', action='store_true')
parser.add_argument('--compile', action='store_true')
parser.add_argument("--gradio_port", type=int, default=11900)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--enhance_video", type=str, default=None)
parser.add_argument("--model_type", type=str, default='t2v')
parser.add_argument("--cache_dir", type=str, default="cache_dir")
parser.add_argument("--prediction_type", type=str, default="v_prediction")
parser.add_argument('--v1_5_scheduler', action='store_true')
parser.add_argument('--sample_method', type=str, default='EulerAncestralDiscrete')
args = parser.parse_args()

# Settings that are fixed for the demo rather than exposed on the CLI.
args.sp = False
args.rescale_betas_zero_snr = True
dtype = torch.bfloat16
# args = init_gpu_env(args)
device = torch.cuda.current_device()

# Optional video enhancer; imported lazily so it is only required when requested.
if args.enhance_video is not None:
    from opensora.sample.VEnhancer.enhance_a_video import VEnhancer
    enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=device)
else:
    enhance_video_model = None

pipeline = prepare_pipeline(args, dtype, device)

# Optional prompt refiner; `generate` checks this global for None.
if args.caption_refiner is not None:
    caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device)
else:
    caption_refiner_model = None

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(LOGO)
    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Row(equal_height=False):
        with gr.Group():
            with gr.Row():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # NOTE(review): with minimum=1 and step=16 the default 29 and the
                # maximum 93 are not on the slider's step grid — confirm intended.
                num_frames = gr.Slider(
                    label="Num Frames",
                    minimum=1,
                    maximum=93,
                    step=16,
                    value=29,
                )
                num_samples = gr.Slider(
                    label="Num Samples",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=7.5,
                )
                inference_steps = gr.Slider(
                    label="Inference steps",
                    minimum=10,
                    maximum=200,
                    step=1,
                    value=50,
                )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Video(autoplay=True, label="Result")
        # result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)
    with gr.Row(), gr.Column():
        gr.Markdown("## Examples (Text-to-Video)")
        # One example row per bundled prompt; the remaining values fill the
        # inputs listed below in the same order.
        examples = [[i, 42, 93, 1, 7.5, 100, True] for i in t2v_prompt_examples]
        gr.Examples(
            examples=examples,
            inputs=[
                prompt, seed, num_frames, num_samples,
                guidance_scale, inference_steps, randomize_seed
            ],
            label='Text-to-Video',
            cache_examples=False,
            outputs=[result, seed],
            fn=generate
        )
    # Pressing Enter in the prompt box and clicking "Run" both call `generate`.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            seed,
            num_frames,
            num_samples,
            guidance_scale,
            inference_steps,
            randomize_seed,
        ],
        outputs=[result, seed],
        api_name="run",
    )

# if __name__ == "__main__":
# Each local rank serves on its own port (gradio_port + local_rank).
demo.queue(max_size=20).launch(
    server_name="0.0.0.0",
    server_port=args.gradio_port+args.local_rank,
    debug=True
)
@torch.no_grad()
@torch.inference_mode()
def generate(
    prompt: str,
    image_1: str,
    image_2: str = None,
    seed: int = 0,
    num_frames: int = 29,
    num_samples: int = 1,
    guidance_scale: float = 4.5,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Sample from the image-conditioned pipeline for one prompt and save the results.

    ``image_1`` (and ``image_2`` when given) are passed to the module-level
    ``pipeline`` as ``conditional_images``. The prompt is optionally rewritten
    by ``caption_refiner_model`` and wrapped in ``POS_PROMPT``. Per-sample
    files and a combined grid are written under ``args.save_img_path``.

    Returns:
        ``(final_path, seed)``: path of the combined grid file and the seed
        that was actually used.

    Raises:
        gr.Error: if ``image_1`` is missing (Run pressed before uploading).

    NOTE(review): block nesting was reconstructed from whitespace-collapsed
    source; confirm the grid-building statements belong to the
    ``num_samples != 1`` video branch.
    """
    # An empty gr.Image yields None; fail with a readable UI message instead of
    # letting [None] reach the pipeline.
    if image_1 is None:
        raise gr.Error("Please provide at least the first conditioning image (Image 1).")

    seed = int(randomize_seed_fn(seed, randomize_seed))
    torch.manual_seed(seed)

    os.makedirs(args.save_img_path, exist_ok=True)

    def sample_path(idx, suffix):
        # Every per-prompt artefact shares this stem; only the suffix differs.
        return os.path.join(
            args.save_img_path,
            f'{args.sample_method}_{idx}_gs{guidance_scale}_s{num_inference_steps}{suffix}',
        )

    single_frame = num_frames == 1
    conditions = [[image_1] if image_2 is None else [image_1, image_2]]
    collected = []
    for index, (cond_images, text) in enumerate(zip(conditions, [prompt])):
        if caption_refiner_model is not None:
            refined = caption_refiner_model.get_refiner_output(text)
            print(f'\nOrigin prompt: {text}\n->\nRefine prompt: {refined}')
            text = refined
        input_prompt = POS_PROMPT.format(text)
        print(cond_images)

        videos = pipeline(
            conditional_images=cond_images,
            prompt=input_prompt,
            negative_prompt=NEG_PROMPT,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_samples_per_prompt=num_samples,
            max_sequence_length=512,
            device=device,
        ).videos

        if not single_frame and enhance_video_model is not None:
            # b t h w c
            videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250)

        if single_frame:
            videos = rearrange(videos, 'b t h w c -> (b t) c h w')
            grid_cols = math.ceil(math.sqrt(videos.shape[0]))
            if num_samples != 1:
                # `frame`, not `image`: the original rebound the outer loop variable here.
                for i, frame in enumerate(videos):
                    save_image(
                        frame / 255.0,
                        sample_path(index, f'_i{i}.jpg'),
                        nrow=grid_cols,
                        normalize=True,
                        value_range=(0, 1),
                    )  # b c h w
            save_image(
                videos / 255.0,
                sample_path(index, '.jpg'),
                nrow=grid_cols,
                normalize=True,
                value_range=(0, 1),
            )  # b c h w
        elif num_samples == 1:
            # highest quality is 10, lowest is 0
            imageio.mimwrite(sample_path(index, '.mp4'), videos[0], fps=args.fps, quality=6)
        else:
            for i in range(num_samples):
                imageio.mimwrite(sample_path(index, f'_i{i}.mp4'), videos[i], fps=args.fps, quality=6)
            videos = save_video_grid(videos)
            imageio.mimwrite(sample_path(index, '.mp4'), videos, fps=args.fps, quality=6)
            videos = videos.unsqueeze(0)  # 1 t h w c
        collected.append(videos)

    video_grids = torch.cat(collected, dim=0)
    final_stem = os.path.join(
        args.save_img_path,
        f'{args.sample_method}_gs{guidance_scale}_s{num_inference_steps}',
    )
    # Random 4-letter tag keeps successive runs with equal settings from overwriting each other.
    random_string = ''.join(random.choices(string.ascii_letters, k=4))
    if single_frame:
        final_path = final_stem + f'_{random_string}.jpg'
        save_image(
            video_grids / 255.0,
            final_path,
            nrow=math.ceil(math.sqrt(len(video_grids))),
            normalize=True,
            value_range=(0, 1),
        )
    else:
        video_grids = save_video_grid(video_grids)
        final_path = final_stem + f'_{random_string}.mp4'
        imageio.mimwrite(final_path, video_grids, fps=args.fps, quality=6)

    print('save path {}'.format(args.save_img_path))
    return final_path, seed
# ---------------------------------------------------------------------------
# Image-to-video Gradio demo: CLI parsing, model loading, UI layout and launch.
# Everything here runs at import time (the __main__ guard below is commented out).
#
# The original parsed the CLI and then overwrote every value with hard-coded
# paths, so the flags had no effect. Those values are now the argparse
# defaults: running with no arguments behaves as before, and explicit flags
# are honoured.
# NOTE(review): `with` nesting was reconstructed from whitespace-collapsed
# source; confirm the layout hierarchy against the real file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path", type=str,
    default="/storage/gyy/hw/Open-Sora-Plan/runs/inpaint_93x1280x1280_stage3_gpu/checkpoint-1692/model_ema",
)
parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5'])
parser.add_argument(
    "--caption_refiner", type=str,
    default="/storage/ongoing/refine_model/llama3_1_instruct_lora/llama3_8B_lora_merged_cn",
)
parser.add_argument("--ae", type=str, default="WFVAEModel_D8_4x8x8")
parser.add_argument("--ae_path", type=str, default="/storage/lcm/wf-vae_trilinear")
parser.add_argument(
    "--text_encoder_name_1", type=str,
    default="/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl",
)
parser.add_argument("--text_encoder_name_2", type=str, default=None)
parser.add_argument("--save_img_path", type=str, default="./test_gradio")
parser.add_argument("--fps", type=int, default=18)
parser.add_argument('--enable_tiling', action='store_true')
parser.add_argument('--save_memory', action='store_true')
parser.add_argument('--compile', action='store_true')
parser.add_argument("--gradio_port", type=int, default=11900)
parser.add_argument("--enhance_video", type=str, default=None)
parser.add_argument("--model_type", type=str, default='i2v')
# Previously set only by assignment after parsing; exposed as flags to match
# the text-to-video server.
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
parser.add_argument("--prediction_type", type=str, default="v_prediction")
parser.add_argument('--sample_method', type=str, default='EulerAncestralDiscrete')
# Tiling was always forced on; keep that default while still accepting the flag.
parser.set_defaults(enable_tiling=True)
args = parser.parse_args()

# Settings that are fixed for the demo rather than exposed on the CLI.
args.rescale_betas_zero_snr = True
args.sp = False
args.crop_for_hw = False
args.max_hw_square = 1048576

dtype = torch.bfloat16
# NOTE(review): `args.local_rank` is read at launch but is not a CLI flag
# here; presumably init_gpu_env sets it — confirm.
args = init_gpu_env(args)
device = torch.cuda.current_device()

# Optional video enhancer; imported lazily so it is only required when requested.
if args.enhance_video is not None:
    from opensora.sample.VEnhancer.enhance_a_video import VEnhancer
    enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=device)
else:
    enhance_video_model = None

pipeline = prepare_pipeline(args, dtype, device)

# Optional prompt refiner; `generate` checks this global for None.
if args.caption_refiner is not None:
    caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device)
else:
    caption_refiner_model = None

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(LOGO)
    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Row(equal_height=False):
        with gr.Group():
            with gr.Row():
                image_1 = gr.Image(type="filepath", label='Image 1')
                image_2 = gr.Image(type="filepath", label='Image 2')
            with gr.Row():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                num_frames = gr.Slider(
                    label="Num Frames",
                    minimum=29,
                    maximum=93,
                    step=16,
                    value=29,
                )
                num_samples = gr.Slider(
                    label="Num Samples",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=7.5,
                )
                inference_steps = gr.Slider(
                    label="Inference steps",
                    minimum=10,
                    maximum=200,
                    step=1,
                    value=50,
                )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Video(autoplay=True, label="Result")
        # result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)

    # with gr.Row(), gr.Column():
    #     gr.Markdown("## Examples (Text-to-Video)")
    #     examples = [[i, 42, 93, 1, 7.5, 100, False] for i in t2v_prompt_examples]
    #     gr.Examples(
    #         examples=examples,
    #         inputs=[
    #             prompt, seed, num_frames, num_samples,
    #             guidance_scale, inference_steps, randomize_seed
    #         ],
    #         label='Text-to-Video',
    #         cache_examples=False,
    #         outputs=[result, seed],
    #         fn=generate
    #     )

    # Pressing Enter in the prompt box and clicking "Run" both call `generate`.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            image_1,
            image_2,
            seed,
            num_frames,
            num_samples,
            guidance_scale,
            inference_steps,
            randomize_seed,
        ],
        outputs=[result, seed],
        api_name="run",
    )

# if __name__ == "__main__":
# Each local rank serves on its own port (gradio_port + local_rank).
demo.queue(max_size=20).launch(
    server_name="0.0.0.0",
    server_port=args.gradio_port+args.local_rank,
    debug=True
)
def set_random_seed(seed):
    """Seed Python ``random``, NumPy and PyTorch (CPU, current and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def ddp_setup():
    """Join the NCCL process group and select this process's CUDA device.

    The device index is read from the ``LOCAL_RANK`` environment variable.
    """
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)


def setup_logger(rank):
    """Configure and return the root logger with a colourised, rank-tagged console handler.

    The root level is set to INFO. The handler is attached only when the root
    logger has no handlers yet, so repeated calls do not duplicate output.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    level_colors = {
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "bold_red",
    }
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(
        ColoredFormatter(
            f"[rank{rank}] %(log_color)s%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            log_colors=level_colors,
            reset=True,
            style="%",
        )
    )

    if not root.handlers:
        root.addHandler(console)
    return root
def check_unused_params(model):
    """Return the names of all parameters of ``model`` whose ``.grad`` is ``None``."""
    return [name for name, param in model.named_parameters() if param.grad is None]


def set_requires_grad_optimizer(optimizer, requires_grad):
    """Set ``requires_grad`` on every parameter held in ``optimizer``'s param groups."""
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            param.requires_grad = requires_grad


def total_params(model):
    """Return the number of trainable parameters of ``model`` in whole millions (truncated)."""
    # Local renamed: the original shadowed the function's own name.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return int(trainable / 1e6)


def get_exp_name(args):
    """Build the experiment name from the run's name, lr, batch size, resolution, sample rate and frames."""
    return f"{args.exp_name}-lr{args.lr:.2e}-bs{args.batch_size}-rs{args.resolution}-sr{args.sample_rate}-fr{args.num_frames}"


def set_train(modules):
    """Put every module in ``modules`` into training mode."""
    for module in modules:
        module.train()


def set_eval(modules):
    """Put every module in ``modules`` into evaluation mode."""
    for module in modules:
        module.eval()


def set_modules_requires_grad(modules, requires_grad):
    """Set ``requires_grad`` on all parameters of every module in ``modules``."""
    for module in modules:
        module.requires_grad_(requires_grad)


def save_checkpoint(
    epoch,
    current_step,
    optimizer_state,
    state_dict,
    scaler_state,
    sampler_state,
    checkpoint_dir,
    filename="checkpoint.ckpt",
    ema_state_dict=None,
):
    """Serialize the training state with ``torch.save`` and return the file path.

    Args:
        epoch, current_step: progress counters stored verbatim.
        optimizer_state, state_dict, scaler_state, sampler_state: state
            objects stored verbatim under their own keys.
        checkpoint_dir: directory the file is written into; must support the
            ``/`` operator with a ``Path`` (e.g. a ``pathlib.Path``).
        filename: name of the checkpoint file inside ``checkpoint_dir``.
        ema_state_dict: EMA weights; ``None`` (the default) stores an empty
            dict, as the previous ``{}`` default did.

    Returns:
        The path of the written checkpoint file.
    """
    # None sentinel replaces the mutable ``{}`` default that all calls shared.
    if ema_state_dict is None:
        ema_state_dict = {}
    filepath = checkpoint_dir / Path(filename)
    torch.save(
        {
            "epoch": epoch,
            "current_step": current_step,
            "optimizer_state": optimizer_state,
            "state_dict": state_dict,
            "ema_state_dict": ema_state_dict,
            "scaler_state": scaler_state,
            "sampler_state": sampler_state,
        },
        filepath,
    )
    return filepath
def valid(global_rank, rank, model, val_dataloader, precision, args):
    """Run one validation pass and collect per-batch metrics on this rank.

    Args:
        global_rank: rank across all processes; only rank 0 shows a progress
            bar and collects reconstructed videos.
        rank: device the inputs and the LPIPS model are moved to.
        model: callable whose first output is the reconstruction.
        val_dataloader: yields dicts with a ``"video"`` tensor.
        precision: dtype passed to ``torch.cuda.amp.autocast``.
        args: reads ``eval_lpips`` and ``eval_num_video_log``.

    Returns:
        ``(psnr_list, lpips_list, video_log)``: one float per batch for each
        metric (``lpips_list`` is empty when ``args.eval_lpips`` is false) and
        up to ``args.eval_num_video_log`` converted videos (rank 0 only).
    """
    if args.eval_lpips:
        lpips_model = lpips.LPIPS(net="alex", spatial=True)
        lpips_model.to(rank)
        lpips_model = DDP(lpips_model, device_ids=[rank])
        lpips_model.requires_grad_(False)
        lpips_model.eval()

    bar = None
    if global_rank == 0:
        bar = tqdm.tqdm(total=len(val_dataloader), desc="Validation...")

    psnr_list = []
    lpips_list = []
    video_log = []
    num_video_log = args.eval_num_video_log

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_dataloader):
            inputs = batch["video"].to(rank)
            with torch.cuda.amp.autocast(dtype=precision):
                outputs = model(inputs)
                video_recon = outputs[0]

            # Upload videos
            if global_rank == 0:
                for i in range(len(video_recon)):
                    if num_video_log <= 0:
                        break
                    video = tensor_to_video(video_recon[i])
                    video_log.append(video)
                    num_video_log -= 1

            # Fold time into the batch axis so metrics are computed per frame.
            inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
            video_recon = rearrange(
                video_recon, "b c t h w -> (b t) c h w"
            ).contiguous()

            # Calculate PSNR
            # NOTE(review): the formula uses a peak value of 1 — confirm it
            # matches the value range of `inputs`.
            mse = torch.mean(torch.square(inputs - video_recon), dim=(1, 2, 3))
            psnr = 20 * torch.log10(1 / torch.sqrt(mse))
            psnr = psnr.mean().detach().cpu().item()

            # Calculate LPIPS
            if args.eval_lpips:
                lpips_score = (
                    lpips_model.forward(inputs, video_recon)
                    .mean()
                    .detach()
                    .cpu()
                    .item()
                )
                lpips_list.append(lpips_score)

            psnr_list.append(psnr)
            if global_rank == 0:
                bar.update()

    # Release gpus memory
    torch.cuda.empty_cache()
    return psnr_list, lpips_list, video_log


def gather_valid_result(psnr_list, lpips_list, video_log_list, rank, world_size):
    """Gather per-rank validation results and reduce them to global values.

    Args:
        psnr_list, lpips_list: this rank's per-batch metric values.
        video_log_list: this rank's logged videos.
        rank: unused; kept for call compatibility.
        world_size: number of ranks taking part in the gather.

    Returns:
        ``(mean_psnr, mean_lpips, videos)``. Each mean is taken over every
        gathered per-batch value; it is NaN when nothing was gathered (e.g.
        LPIPS disabled). ``videos`` is the concatenation of all ranks' logs.
    """
    gathered_psnr_list = [None] * world_size
    gathered_lpips_list = [None] * world_size
    gathered_video_logs = [None] * world_size
    dist.all_gather_object(gathered_psnr_list, psnr_list)
    dist.all_gather_object(gathered_lpips_list, lpips_list)
    dist.all_gather_object(gathered_video_logs, video_log_list)

    def _global_mean(per_rank_values):
        # Flatten first: ranks may return different numbers of batches, and a
        # ragged nested list cannot be turned into a rectangular ndarray.
        flat = list(chain.from_iterable(per_rank_values))
        # Return NaN explicitly instead of a "mean of empty slice" warning.
        return np.mean(flat) if flat else np.float64("nan")

    return (
        _global_mean(gathered_psnr_list),
        _global_mean(gathered_lpips_list),
        list(chain(*gathered_video_logs)),
    )
) if args.pretrained_model_name_or_path is not None: if global_rank == 0: logger.warning( f"You are loading a checkpoint from `{args.pretrained_model_name_or_path}`." ) model = model_cls.from_pretrained( args.pretrained_model_name_or_path, ignore_mismatched_sizes=args.ignore_mismatched_sizes, low_cpu_mem_usage=False, device_map=None, ) else: if global_rank == 0: logger.warning(f"Model will be inited randomly.") model = model_cls.from_config(args.model_config) if global_rank == 0: logger.warning("Connecting to WANDB...") model_config = dict(**model.config) args_config = dict(**vars(args)) if 'resolution' in model_config: del model_config['resolution'] wandb.init( project=os.environ.get("WANDB_PROJECT", "causalvideovae"), config=dict(**model_config, **args_config), name=get_exp_name(args), ) dist.barrier() # Load discriminator model disc_cls = resolve_str_to_obj(args.disc_cls, append=False) logger.warning( f"disc_class: {args.disc_cls} perceptual_weight: {args.perceptual_weight} loss_type: {args.loss_type}" ) disc = disc_cls( disc_start=args.disc_start, disc_weight=args.disc_weight, kl_weight=args.kl_weight, logvar_init=args.logvar_init, perceptual_weight=args.perceptual_weight, loss_type=args.loss_type, wavelet_weight=args.wavelet_weight ) # DDP model = model.to(rank, ) model = DDP( model, device_ids=[rank], find_unused_parameters=args.find_unused_parameters ) disc = disc.to(rank) disc = DDP( disc, device_ids=[rank], find_unused_parameters=args.find_unused_parameters ) # Load dataset dataset = TrainVideoDataset( args.video_path, sequence_length=args.num_frames, resolution=args.resolution, sample_rate=args.sample_rate, dynamic_sample=args.dynamic_sample, cache_file="idx.pkl", is_main_process=global_rank == 0, ) ddp_sampler = CustomDistributedSampler(dataset) dataloader = DataLoader( dataset, batch_size=args.batch_size, sampler=ddp_sampler, pin_memory=True, num_workers=args.dataset_num_worker, ) val_dataset = ValidVideoDataset( real_video_dir=args.eval_video_path, 
num_frames=args.eval_num_frames, sample_rate=args.eval_sample_rate, crop_size=args.eval_resolution, resolution=args.eval_resolution, ) indices = range(args.eval_subset_size) val_dataset = Subset(val_dataset, indices=indices) val_sampler = CustomDistributedSampler(val_dataset) val_dataloader = DataLoader( val_dataset, batch_size=args.eval_batch_size, sampler=val_sampler, pin_memory=True, ) # Optimizer modules_to_train = [module for module in model.module.get_decoder()] if not args.freeze_encoder: modules_to_train += [module for module in model.module.get_encoder()] else: for module in model.module.get_encoder(): module.eval() module.requires_grad_(False) logger.warning("Encoder is freezed!") parameters_to_train = [] for module in modules_to_train: parameters_to_train += list(filter(lambda p: p.requires_grad, module.parameters())) gen_optimizer = torch.optim.AdamW(parameters_to_train, lr=args.lr, weight_decay=1e-4) disc_optimizer = torch.optim.AdamW( filter(lambda p: p.requires_grad, disc.module.discriminator.parameters()), lr=args.lr, weight_decay=0.01 ) # AMP scaler scaler = torch.cuda.amp.GradScaler() precision = torch.bfloat16 if args.mix_precision == "fp16": precision = torch.float16 elif args.mix_precision == "fp32": precision = torch.float32 print(precision) # Load from checkpoint start_epoch = 0 current_step = 0 if args.resume_from_checkpoint: if not os.path.isfile(args.resume_from_checkpoint): raise Exception( f"Make sure `{args.resume_from_checkpoint}` is a ckpt file." 
def update_bar(bar):
    """Advance the progress bar by one tick on the global-rank-0 process.

    Any other rank returns without touching ``bar``. The description is
    rebuilt from the ``bar_desc`` template and the current ``epoch`` taken
    from the enclosing scope; the loss slot is always the literal "-".
    """
    if global_rank != 0:
        return
    description = bar_desc.format(current_epoch=epoch, loss="-")
    bar.desc = description
    bar.update()
def _build_arg_parser():
    """Create the command-line parser for the distributed VAE training script.

    Argument names, types and defaults are the script's public interface and
    are unchanged. Help strings are filled in only where the training loop in
    this file shows what the option does; the rest are left empty.
    """
    parser = argparse.ArgumentParser(description="Distributed Training")

    # Exp setting
    parser.add_argument("--exp_name", type=str, default="test", help="experiment name")
    parser.add_argument("--seed", type=int, default=1234, help="seed")

    # Training setting
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train"
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="total number of training steps shown on the progress bar "
        "(default: epochs * len(dataloader))",
    )
    parser.add_argument(
        "--save_ckpt_step",
        type=int,
        default=1000,
        help="save a checkpoint every N training steps",
    )
    parser.add_argument("--ckpt_dir", type=str, default="./results/", help="")
    parser.add_argument(
        "--batch_size", type=int, default=1, help="batch size for training"
    )
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    parser.add_argument("--log_steps", type=int, default=5, help="log steps")
    parser.add_argument(
        "--freeze_encoder",
        action="store_true",
        help="freeze the encoder and train only the decoder",
    )
    parser.add_argument("--clip_grad_norm", type=float, default=1e5, help="")

    # Data
    parser.add_argument("--video_path", type=str, default=None, help="")
    parser.add_argument("--num_frames", type=int, default=17, help="")
    parser.add_argument("--resolution", type=int, default=256, help="")
    parser.add_argument("--sample_rate", type=int, default=2, help="")
    parser.add_argument("--dynamic_sample", action="store_true", help="")

    # Generator model
    parser.add_argument("--ignore_mismatched_sizes", action="store_true", help="")
    parser.add_argument("--find_unused_parameters", action="store_true", help="")
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str, default=None, help=""
    )
    parser.add_argument("--model_name", type=str, default=None, help="")
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="path to a checkpoint file to resume training from",
    )
    parser.add_argument("--not_resume_training_process", action="store_true", help="")
    parser.add_argument("--model_config", type=str, default=None, help="")
    parser.add_argument(
        "--mix_precision",
        type=str,
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="precision for training",
    )
    parser.add_argument(
        "--wavelet_loss",
        action="store_true",
        help="pass the wavelet coefficients returned by the model to the loss",
    )
    parser.add_argument("--wavelet_weight", type=float, default=0.1, help="")

    # Discriminator Model
    parser.add_argument("--load_disc_from_checkpoint", type=str, default=None, help="")
    parser.add_argument(
        "--disc_cls",
        type=str,
        default="opensora.models.causalvideovae.model.losses.LPIPSWithDiscriminator3D",
        help="",
    )
    parser.add_argument("--disc_start", type=int, default=5, help="")
    parser.add_argument("--disc_weight", type=float, default=0.5, help="")
    parser.add_argument("--kl_weight", type=float, default=1e-06, help="")
    parser.add_argument("--perceptual_weight", type=float, default=1.0, help="")
    parser.add_argument("--loss_type", type=str, default="l1", help="")
    parser.add_argument("--logvar_init", type=float, default=0.0, help="")

    # Validation
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=1000,
        help="run validation every N training steps",
    )
    parser.add_argument("--eval_video_path", type=str, default=None, help="")
    parser.add_argument("--eval_num_frames", type=int, default=17, help="")
    parser.add_argument("--eval_resolution", type=int, default=256, help="")
    parser.add_argument("--eval_sample_rate", type=int, default=1, help="")
    parser.add_argument(
        "--eval_batch_size", type=int, default=8, help="batch size for validation"
    )
    parser.add_argument(
        "--eval_subset_size",
        type=int,
        default=100,
        help="number of validation samples to evaluate",
    )
    parser.add_argument("--eval_num_video_log", type=int, default=2, help="")
    parser.add_argument("--eval_lpips", action="store_true", help="")

    # Dataset
    parser.add_argument("--dataset_num_worker", type=int, default=4, help="")

    # EMA
    parser.add_argument(
        "--ema",
        action="store_true",
        help="keep an exponential moving average of the model weights",
    )
    parser.add_argument("--ema_decay", type=float, default=0.999, help="EMA decay")
    return parser


def main():
    """Parse the command line, call ``set_random_seed`` with ``--seed`` and run ``train``."""
    args = _build_arg_parser().parse_args()
    set_random_seed(args.seed)
    train(args)
""" A minimal training script for DiT using PyTorch DDP. """ import argparse import logging import math import os import shutil from pathlib import Path from typing import Optional import gc import numpy as np from einops import rearrange import torch.utils import torch.utils.data from tqdm import tqdm import yaml from opensora.adaptor.modules import replace_with_fp32_forwards try: import torch_npu from opensora.npu_config import npu_config from opensora.acceleration.parallel_states import initialize_sequence_parallel_state, \ destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state from opensora.acceleration.communications import prepare_parallel_data, broadcast except: torch_npu = None npu_config = None from opensora.utils.parallel_states import initialize_sequence_parallel_state, \ destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state from opensora.utils.communications import prepare_parallel_data, broadcast pass import time from dataclasses import field, dataclass from torch.utils.data import DataLoader from copy import deepcopy import accelerate import torch import json from torch.nn import functional as F import transformers from transformers.utils import ContextManagers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedType, ProjectConfiguration, set_seed from accelerate.state import AcceleratorState from packaging import version from tqdm.auto import tqdm import diffusers from diffusers import DDPMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, CogVideoXDDIMScheduler from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, is_wandb_available from opensora.models.causalvideovae import ae_stride_config, ae_channel_config from opensora.models.causalvideovae import ae_norm, ae_denorm from opensora.models import 
class ProgressInfo:
    """Mutable holder for the step counter and the running training loss.

    Both fields are plain attributes, so code holding a reference can read
    and overwrite them in place.
    """

    def __init__(self, global_step, train_loss=0.0):
        self.global_step, self.train_loss = global_step, train_loss
with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed, device_specific=True) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) # backup the config file shutil.copy(args.mask_config, os.path.join(args.output_dir, "mask_config.yaml")) # For mixed precision training we cast all non-trainable weigths to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Create model: # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. # For this to work properly all models must be run through `accelerate.prepare`. But accelerate # will try to assign the same optimizer with the same weights to all models during # `deepspeed.initialize`, which of course doesn't work. # # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 # frozen models from being partitioned during `zero.Init` which gets called during # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
def deepspeed_zero_init_disabled_context_manager():
    """
    Build the list of context managers to enter while the frozen models are created.

    Returns a one-element list holding the DeepSpeed plugin's
    ``zero3_init_context_manager(enable=False)`` when the accelerate state is
    initialized and carries a DeepSpeed plugin; otherwise an empty list.
    """
    plugin = None
    if accelerate.state.is_initialized():
        plugin = AcceleratorState().deepspeed_plugin

    contexts = []
    if plugin is not None:
        contexts.append(plugin.zero3_init_context_manager(enable=False))
    return contexts
assert args.max_height % ae_stride_h == 0, f"Height must be divisible by ae_stride_h, but found Height ({args.max_height}), ae_stride_h ({ae_stride_h})." assert args.max_width % ae_stride_h == 0, f"Width size must be divisible by ae_stride_h, but found Width ({args.max_width}), ae_stride_h ({ae_stride_h})." args.stride_t = ae_stride_t * patch_size_t args.stride = ae_stride_h * patch_size_h ae.latent_size = latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w) args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1 mask_compressor = MaskCompressor(ae_stride_h=ae_stride_h, ae_stride_w=ae_stride_w, ae_stride_t=ae_stride_t) noise_adder = None if args.add_noise_to_condition: noise_adder = GaussianNoiseAdder(mean=-3.0, std=0.5, clear_ratio=0.05) model_kwargs = {'vae_scale_factor_t': ae_stride_t} model = Diffusion_models[args.model]( in_channels=ae_channel_config[args.ae], out_channels=ae_channel_config[args.ae], sample_size_h=latent_size, sample_size_w=latent_size, sample_size_t=latent_size_t, interpolation_scale_h=args.interpolation_scale_h, interpolation_scale_w=args.interpolation_scale_w, interpolation_scale_t=args.interpolation_scale_t, sparse1d=args.sparse1d, sparse_n=args.sparse_n, **model_kwargs, ) # # use pretrained model? 
if args.pretrained:
    # Initialise `model` from `args.pretrained`, which may be a .safetensors
    # file, a directory holding a config.json, or any other file that
    # torch.load can read.
    model_state_dict = model.state_dict()
    print(f'Load from {args.pretrained}')
    if args.pretrained.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        pretrained_checkpoint = safe_load(args.pretrained, device="cpu")
        pretrained_keys = set(pretrained_checkpoint.keys())
        model_keys = set(model_state_dict.keys())
        common_keys = list(pretrained_keys & model_keys)
        # Keep only tensors whose key exists in the model and whose element count matches.
        checkpoint = {
            k: pretrained_checkpoint[k]
            for k in common_keys
            if model_state_dict[k].numel() == pretrained_checkpoint[k].numel()
        }
        # When the checkpoint has no 'pos_embed_masked_hidden_states.0.proj' weights,
        # reuse the 'pos_embed.proj' weights for them.
        if 'pos_embed_masked_hidden_states.0.proj.weight' not in checkpoint:
            checkpoint['pos_embed_masked_hidden_states.0.proj.weight'] = checkpoint['pos_embed.proj.weight']
            checkpoint['pos_embed_masked_hidden_states.0.proj.bias'] = checkpoint['pos_embed.proj.bias']
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
    elif os.path.isdir(args.pretrained):
        if os.path.exists(os.path.join(args.pretrained, 'config.json')):
            with open(os.path.join(args.pretrained, 'config.json')) as f:
                config = json.load(f)
            # The config names the model class used to load the directory.
            class_name = config['_class_name']
            print(f'Load from {args.pretrained} with class_name {class_name}')
            load_model = Diffusion_models_class[class_name].from_pretrained(args.pretrained)
            missing_keys, unexpected_keys = model.load_state_dict(load_model.state_dict(), strict=False)
            if 'pos_embed_masked_hidden_states.0.proj.weight' in missing_keys:
                # Same fallback as the file branches: copy pos_embed.proj into the
                # projection that the loaded model did not provide.
                model.pos_embed_masked_hidden_states[0].proj.weight.data = deepcopy(load_model.pos_embed.proj.weight.data)
                model.pos_embed_masked_hidden_states[0].proj.bias.data = deepcopy(load_model.pos_embed.proj.bias.data)
                assert torch.equal(model.pos_embed_masked_hidden_states[0].proj.weight.data, load_model.pos_embed.proj.weight.data)
                assert torch.equal(model.pos_embed_masked_hidden_states[0].proj.bias.data, load_model.pos_embed.proj.bias.data)
                # These two keys are now filled in, so drop them from the report below.
                missing_keys.remove('pos_embed_masked_hidden_states.0.proj.weight')
                missing_keys.remove('pos_embed_masked_hidden_states.0.proj.bias')
            del load_model
        else:
            raise ValueError(f'Invalid pretrained model path: {args.pretrained}, you should provide a valid pretrained model path within a valid config.json file!')
    else:
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        pretrained_checkpoint = torch.load(args.pretrained, map_location='cpu')
        # Unwrap checkpoints that store the state dict under a top-level 'model' key.
        # The test must be made on `pretrained_checkpoint`: `checkpoint` is not bound
        # yet in this branch, so testing it raised UnboundLocalError.
        if 'model' in pretrained_checkpoint:
            pretrained_checkpoint = pretrained_checkpoint['model']
        pretrained_keys = set(pretrained_checkpoint.keys())
        model_keys = set(model_state_dict.keys())
        common_keys = list(pretrained_keys & model_keys)
        # Keep only tensors whose key exists in the model and whose element count matches.
        checkpoint = {
            k: pretrained_checkpoint[k]
            for k in common_keys
            if model_state_dict[k].numel() == pretrained_checkpoint[k].numel()
        }
        if 'pos_embed_masked_hidden_states.0.proj.weight' not in checkpoint:
            checkpoint['pos_embed_masked_hidden_states.0.proj.weight'] = checkpoint['pos_embed.proj.weight']
            checkpoint['pos_embed_masked_hidden_states.0.proj.bias'] = checkpoint['pos_embed.proj.bias']
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
    print(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}')
    print(f'Successfully load {len(model_state_dict) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!')
if not args.extra_save_mem: ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype) text_enc_1.to(accelerator.device, dtype=weight_dtype) if text_enc_2 is not None: text_enc_2.to(accelerator.device, dtype=weight_dtype) # Create EMA for the unet. if args.use_ema: ema_model = deepcopy(model) ema_model = EMAModel(ema_model.parameters(), decay=args.ema_decay, update_after_step=args.ema_start_step, model_cls=Diffusion_models_class[args.model], model_config=ema_model.config) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: if args.use_ema: ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) for i, model in enumerate(models): model.save_pretrained(os.path.join(output_dir, "model")) if weights: # Don't pop if empty # make sure to pop weight so that corresponding model is not saved again weights.pop() def load_model_hook(models, input_dir): if args.use_ema: load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), Diffusion_models_class[args.model]) ema_model.load_state_dict(load_model.state_dict()) ema_model.to(accelerator.device) del load_model for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() # load diffusers style into model load_model = Diffusion_models_class[args.model].from_pretrained(input_dir, subfolder="model") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if 
args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True params_to_optimize = model.parameters() # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): logger.warning( f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." "Defaulting to adamW" ) args.optimizer = "adamw" if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warning( f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}" ) if args.optimizer.lower() == "adamw": if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) if args.optimizer.lower() == "prodigy": try: import prodigyopt except ImportError: raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy if args.learning_rate <= 0.1: logger.warning( "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, decouple=args.prodigy_decouple, use_bias_correction=args.prodigy_use_bias_correction, safeguard_warmup=args.prodigy_safeguard_warmup, ) logger.info(f"optimizer: {optimizer}") # Setup data: if args.trained_data_global_step is not None: initial_global_step_for_sampler = args.trained_data_global_step else: initial_global_step_for_sampler = 0 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size args.total_batch_size = total_batch_size if args.min_hxw is None: args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args) sampler = LengthGroupedSampler( args.train_batch_size, world_size=accelerator.num_processes, gradient_accumulation_size=args.gradient_accumulation_steps, initial_global_step=initial_global_step_for_sampler, lengths=train_dataset.lengths, group_data=args.group_data, ) train_dataloader = DataLoader( train_dataset, shuffle=False, # pin_memory=True, collate_fn=Collate(args), batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, sampler=sampler, drop_last=True, # prefetch_factor=4 ) logger.info(f'after train_dataloader') # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) # Prepare everything with our `accelerator`. # model.requires_grad_(False) # model.pos_embed.requires_grad_(True) # model.patch_embed.requires_grad_(True) logger.info(f'before accelerator.prepare') model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) logger.info(f'after accelerator.prepare') if args.use_ema: ema_model.to(accelerator.device) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. # NOTE wandb if accelerator.is_main_process: logger.info("init trackers...") project_name = os.getenv('PROJECT', os.path.basename(args.output_dir)) entity = os.getenv('ENTITY', None) run_name = os.getenv('WANDB_NAME', None) init_kwargs = { "entity": entity, "run_name": run_name, } accelerator.init_trackers(project_name=project_name, config=vars(args), init_kwargs=init_kwargs) # Train! 
print(f" Args = {args}") logger.info(f" Args = {args}") logger.info("***** Running training *****") logger.info(f" Model = {model}") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" Total optimization steps (num_update_steps_per_epoch) = {num_update_steps_per_epoch}") logger.info(f" Total training parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B") logger.info(f" AutoEncoder = {args.ae}; Dtype = {ae.vae.dtype}; Parameters = {sum(p.numel() for p in ae.parameters()) / 1e9} B") logger.info(f" Text_enc_1 = {args.text_encoder_name_1}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_1.parameters()) / 1e9} B") if args.text_encoder_name_2 is not None: logger.info(f" Text_enc_2 = {args.text_encoder_name_2}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_2.parameters()) / 1e9} B") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
def sync_gradients_info(loss):
    """Per-step bookkeeping: EMA update, progress/loss logging and checkpointing.

    In order, this:
      * steps the EMA model when ``args.use_ema`` is set;
      * advances the progress bar and ``progress_info.global_step`` by one;
      * logs the accumulated ``progress_info.train_loss`` and resets it to 0.0;
      * on DeepSpeed runs (every process) or on the main process, saves the
        accelerator state when the step is a multiple of
        ``args.checkpointing_steps`` or equals
        ``args.after_one_epoch_global_step``, after the main process has
        removed the oldest ``checkpoint*`` directories so that at most
        ``args.checkpoints_total_limit - 1`` remain before the new save;
      * shows the step loss and current learning rate on the progress bar.

    ``loss`` must support ``.detach().item()``. ``start_time`` is read from
    module scope. NOTE(review): presumably called once per synchronized
    optimizer step; the call site is outside this view.
    """
    if args.use_ema:
        ema_model.step(model.parameters())
    progress_bar.update(1)
    progress_info.global_step += 1
    step = progress_info.global_step
    elapsed = time.time() - start_time
    accelerator.log({"train_loss": progress_info.train_loss}, step=step)
    if not (torch_npu is None or npu_config is None):
        npu_config.print_msg(
            f"Step: [{step}], local_loss={loss.detach().item()}, "
            f"train_loss={progress_info.train_loss}, time_cost={elapsed}",
            rank=0,
        )
    progress_info.train_loss = 0.0

    # DeepSpeed needs every process to take part in the save; otherwise only
    # the main process saves.
    on_saving_process = (
        accelerator.distributed_type == DistributedType.DEEPSPEED
        or accelerator.is_main_process
    )
    if on_saving_process and (
        step % args.checkpointing_steps == 0
        or step == args.after_one_epoch_global_step
    ):
        # Prune before saving so the new checkpoint does not exceed the limit.
        if accelerator.is_main_process and args.checkpoints_total_limit is not None:
            existing = sorted(
                (entry for entry in os.listdir(args.output_dir) if entry.startswith("checkpoint")),
                key=lambda entry: int(entry.split("-")[1]),
            )
            if len(existing) >= args.checkpoints_total_limit:
                surplus = len(existing) - args.checkpoints_total_limit + 1
                stale = existing[0:surplus]
                logger.info(
                    f"{len(existing)} checkpoints already exist, removing {len(stale)} checkpoints"
                )
                logger.info(f"removing checkpoints: {', '.join(stale)}")
                for entry in stale:
                    shutil.rmtree(os.path.join(args.output_dir, entry))

        save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
        accelerator.save_state(save_path)
        logger.info(f"Saved state to {save_path}")

    progress_bar.set_postfix(
        step_loss=loss.detach().item(),
        lr=lr_scheduler.get_last_lr()[0],
    )
https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((model_input.shape[0], model_input.shape[1], 1, 1, 1), device=model_input.device) bsz = model_input.shape[0] # Sample a random timestep for each image without bias. if accelerator.num_processes > noise_scheduler.config.num_train_timesteps: timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device) else: timesteps = explicit_uniform_sampling( T=noise_scheduler.config.num_train_timesteps, n=accelerator.num_processes, rank=accelerator.process_index, bsz=bsz, device=model_input.device, ) # print(f'rank: {accelerator.process_index}, timesteps: {timesteps}') if get_sequence_parallel_state(): # image do not need sp, disable when image batch broadcast(timesteps) # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) model_pred = model( torch.cat([noisy_model_input, masked_input, video_mask], dim=1), timesteps, **model_kwargs, )[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) elif noise_scheduler.config.prediction_type == "sample": # We set the target to latents here, but the model_pred will return the noise sample prediction. target = model_input # We will have to subtract the noise residual from the prediction to get the target sample. 
model_pred = model_pred - noise else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") mask = model_kwargs.get('attention_mask', None) if get_sequence_parallel_state(): if torch.all(mask.bool()): mask = None # mask (sp_bs*b t h w) assert mask is None b, c, _, _, _ = model_pred.shape if mask is not None: mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float() # b t h w -> b c t h w mask = mask.reshape(b, -1) if args.snr_gamma is None: # model_pred: b c t h w, attention_mask: b t h w loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.reshape(b, -1) if mask is not None: loss = (loss * mask).sum() / mask.sum() # mean loss on unpad patches else: loss = loss.mean() else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( dim=1 )[0] if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = mse_loss_weights / (snr + 1) else: raise NameError(f'{noise_scheduler.config.prediction_type}') loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.reshape(b, -1) mse_loss_weights = mse_loss_weights.reshape(b, 1) if mask is not None: loss = (loss * mask * mse_loss_weights).sum() / mask.sum() # mean loss on unpad patches else: loss = (loss * mse_loss_weights).mean() # Gather the losses across all processes for logging (if we use distributed training). 
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() progress_info.train_loss += avg_loss.detach().item() / args.gradient_accumulation_steps # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = model.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() if accelerator.sync_gradients: sync_gradients_info(loss) if prof is not None: prof.step() return loss def train_one_step(step_, data_item_, prof_=None): train_loss = 0.0 x, attn_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = data_item_ if accelerator.is_main_process: print(f'\nstep: {step_}, x: {x.shape}, dtype: {x.dtype}') # assert not torch.any(torch.isnan(x)), 'torch.any(torch.isnan(x))' # print('after data collate') # print(f'x: {x.shape}, attn_mask: {attn_mask.shape}, input_ids_1: {input_ids_1.shape}, cond_mask_1: {cond_mask_1.shape}, input_ids_2: {input_ids_2.shape}, cond_mask_2: {cond_mask_2.shape}') if args.extra_save_mem: torch.cuda.empty_cache() ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype) text_enc_1.to(accelerator.device, dtype=weight_dtype) if text_enc_2 is not None: text_enc_2.to(accelerator.device, dtype=weight_dtype) x = x.to(accelerator.device, dtype=ae.vae.dtype) # B C T H W # x = x.to(accelerator.device, dtype=torch.float32) # B C T H W attn_mask = attn_mask.to(accelerator.device) # B T H W input_ids_1 = input_ids_1.to(accelerator.device) # B 1 L cond_mask_1 = cond_mask_1.to(accelerator.device) # B 1 L input_ids_2 = input_ids_2.to(accelerator.device) if input_ids_2 is not None else input_ids_2 # B 1 L cond_mask_2 = cond_mask_2.to(accelerator.device) if cond_mask_2 is not None else cond_mask_2 # B 1 L with torch.no_grad(): B, N, L = input_ids_1.shape # B 1 L # use batch inference input_ids_1 = input_ids_1.reshape(-1, L) cond_mask_1 = cond_mask_1.reshape(-1, L) cond_1 = text_enc_1(input_ids_1, cond_mask_1) # B 
L D cond_1 = cond_1.reshape(B, N, L, -1) cond_mask_1 = cond_mask_1.reshape(B, N, L) if text_enc_2 is not None: B_, N_, L_ = input_ids_2.shape # B 1 L input_ids_2 = input_ids_2.reshape(-1, L_) cond_2 = text_enc_2(input_ids_2, cond_mask_2) # B D cond_2 = cond_2.reshape(B_, 1, -1) # B 1 D else: cond_2 = None # Map input images to latent space + normalize latents x, masked_x, mask = x[:, :3], x[:, 3:6], x[:, 6:7] # Adding noise to control frames enhances generalization ability. if noise_adder is not None: masked_x = noise_adder(masked_x, mask) x, masked_x = ae.encode(x), ae.encode(masked_x) mask = mask_compressor(mask) x = torch.cat([x, masked_x, mask], dim=1) if args.extra_save_mem: ae.vae.to('cpu') text_enc_1.to('cpu') if text_enc_2 is not None: text_enc_2.to('cpu') torch.cuda.empty_cache() current_step_frame = x.shape[2] current_step_sp_state = get_sequence_parallel_state() if args.sp_size != 1: # enable sp if current_step_frame == 1: # but image do not need sp set_sequence_parallel_state(False) else: set_sequence_parallel_state(True) if get_sequence_parallel_state(): x, cond_1, attn_mask, cond_mask_1, cond_2 = prepare_parallel_data( x, cond_1, attn_mask, cond_mask_1, cond_2 ) # x (b c t h w) -gather0-> (sp*b c t h w) -scatter2-> (sp*b c t//sp h w) # cond_1 (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d) # attn_mask (b t*sp h w) -gather0-> (sp*b t*sp h w) -scatter1-> (sp*b t h w) # cond_mask_1 (b sp l) -gather0-> (sp*b sp l) -scatter1-> (sp*b 1 l) # cond_2 (b sp d) -gather0-> (sp*b sp d) -scatter1-> (sp*b 1 d) for iter in range(args.train_batch_size * args.sp_size // args.train_sp_batch_size): with accelerator.accumulate(model): # x (sp_bs*b c t//sp h w) # cond_1 (sp_bs*b 1 l/sp d) # attn_mask (sp_bs*b t h w) # cond_mask_1 (sp_bs*b 1 l) # cond_2 (sp_bs*b 1 d) st_idx = iter * args.train_sp_batch_size ed_idx = (iter + 1) * args.train_sp_batch_size model_kwargs = dict( encoder_hidden_states=cond_1[st_idx: ed_idx], attention_mask=attn_mask[st_idx: 
ed_idx], encoder_attention_mask=cond_mask_1[st_idx: ed_idx], pooled_projections=cond_2[st_idx: ed_idx] if cond_2 is not None else None, ) run(x[st_idx: ed_idx], model_kwargs, prof_) else: with accelerator.accumulate(model): # assert not torch.any(torch.isnan(x)), 'after vae' x = x.to(weight_dtype) model_kwargs = dict( encoder_hidden_states=cond_1, attention_mask=attn_mask, encoder_attention_mask=cond_mask_1, pooled_projections=cond_2 ) run(x, model_kwargs, prof_) set_sequence_parallel_state(current_step_sp_state) # in case the next step use sp, which need broadcast(timesteps) if progress_info.global_step >= args.max_train_steps: return True return False def train_one_epoch(prof_=None): # for epoch in range(first_epoch, args.num_train_epochs): progress_info.train_loss = 0.0 if progress_info.global_step >= args.max_train_steps: return True args.after_one_epoch_global_step = progress_info.global_step + len(train_dataloader) // args.gradient_accumulation_steps - 1 for step, data_item in enumerate(train_dataloader): if train_one_step(step, data_item, prof_): break if step >= 2 and torch_npu is not None and npu_config is not None: npu_config.free_mm() if npu_config is not None and npu_config.on_npu and npu_config.profiling: experimental_config = torch_npu.profiler._ExperimentalConfig( profiler_level=torch_npu.profiler.ProfilerLevel.Level1, aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization ) profile_output_path = f"/home/image_data/npu_profiling_t2v/{os.getenv('PROJECT_NAME', 'local')}" os.makedirs(profile_output_path, exist_ok=True) with torch_npu.profiler.profile( activities=[ torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU, ], with_stack=True, record_shapes=True, profile_memory=True, experimental_config=experimental_config, schedule=torch_npu.profiler.schedule( wait=npu_config.profiling_step, warmup=0, active=1, repeat=1, skip_first=0 ), on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f"{profile_output_path}/") ) as 
def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True`` for any
    non-empty string, so such flags could never be switched off from the command line.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as ``"true"``/``"0"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # dataset & dataloader
    parser.add_argument("--dataset", type=str, required=True)
    # The original passed `required=''` (a falsy non-bool); spelled out as the equivalent False.
    parser.add_argument("--data", type=str, required=False)
    parser.add_argument("--sample_rate", type=int, default=1)
    parser.add_argument("--train_fps", type=int, default=24)
    parser.add_argument("--drop_short_ratio", type=float, default=1.0)
    parser.add_argument("--speed_factor", type=float, default=1.0)
    parser.add_argument("--num_frames", type=int, default=65)
    parser.add_argument("--max_height", type=int, default=320)
    parser.add_argument("--max_width", type=int, default=240)
    parser.add_argument("--max_hxw", type=int, default=None)
    parser.add_argument("--min_hxw", type=int, default=None)
    parser.add_argument("--ood_img_ratio", type=float, default=0.0)
    parser.add_argument("--use_img_from_vid", action="store_true")
    parser.add_argument("--model_max_length", type=int, default=512)
    parser.add_argument('--cfg', type=float, default=0.1)
    parser.add_argument("--dataloader_num_workers", type=int, default=10, help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")
    parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--group_data", action="store_true")
    parser.add_argument("--hw_stride", type=int, default=32)
    parser.add_argument("--force_resolution", action="store_true")
    parser.add_argument("--trained_data_global_step", type=int, default=None)
    parser.add_argument("--use_decord", action="store_true")

    # text encoder & vae & diffusion model
    parser.add_argument('--vae_fp32', action='store_true')
    parser.add_argument('--extra_save_mem', action='store_true')
    parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="Latte-XL/122")
    parser.add_argument('--enable_tiling', action='store_true')
    parser.add_argument('--interpolation_scale_h', type=float, default=1.0)
    parser.add_argument('--interpolation_scale_w', type=float, default=1.0)
    parser.add_argument('--interpolation_scale_t', type=float, default=1.0)
    parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse")
    parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse")
    parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl')
    parser.add_argument("--text_encoder_name_2", type=str, default=None)
    parser.add_argument("--cache_dir", type=str, default='./cache_dir')
    parser.add_argument("--pretrained", type=str, default=None)
    parser.add_argument('--sparse1d', action='store_true')
    parser.add_argument('--sparse_n', type=int, default=2)
    parser.add_argument('--cogvideox_scheduler', action='store_true')
    parser.add_argument('--v1_5_scheduler', action='store_true')
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.")

    # diffusion setting
    parser.add_argument("--snr_gamma", type=float, default=None, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.")
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--noise_offset", type=float, default=0.0, help="The scale of noise offset.")
    parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.")
    parser.add_argument('--rescale_betas_zero_snr', action='store_true')

    # validation & logs
    parser.add_argument("--enable_profiling", action="store_true")
    parser.add_argument("--num_sampling_steps", type=int, default=20)
    parser.add_argument('--guidance_scale', type=float, default=4.5)
    parser.add_argument("--enable_tracker", action="store_true")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--output_dir", type=str, default=None, help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store."))
    parser.add_argument("--checkpointing_steps", type=int, default=500,
                        help=(
                            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
                            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
                            " training using `--resume_from_checkpoint`."
                        ),
                        )
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help=(
                            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
                            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
                        ),
                        )
    parser.add_argument("--logging_dir", type=str, default="logs",
                        help=(
                            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
                            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
                        ),
                        )
    parser.add_argument("--report_to", type=str, default="tensorboard",
                        help=(
                            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
                            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
                        ),
                        )

    # optimizer & scheduler
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument("--max_train_steps", type=int, default=1000000, help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--optimizer", type=str, default="adamW", help='The optimizer type to use. Choose between ["AdamW", "prodigy"]')
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers.")
    # `type=bool` treated every non-empty string (including "False") as True; use str2bool.
    parser.add_argument("--prodigy_decouple", type=str2bool, default=True, help="Use AdamW style decoupled weight decay")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use for unet params")
    parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder")
    parser.add_argument("--adam_epsilon", type=float, default=1e-15, help="Epsilon value for the Adam optimizer and Prodigy optimizers.")
    parser.add_argument("--prodigy_use_bias_correction", type=str2bool, default=True, help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW")
    parser.add_argument("--prodigy_safeguard_warmup", type=str2bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--prodigy_beta3", type=float, default=None,
                        help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
                             "uses the value of square root of beta2. Ignored if optimizer is adamW",
                        )
    parser.add_argument("--lr_scheduler", type=str, default="constant",
                        help=(
                            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
                            ' "constant", "constant_with_warmup"]'
                        ),
                        )
    parser.add_argument("--allow_tf32", action="store_true",
                        help=(
                            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
                            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
                        ),
                        )
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],
                        help=(
                            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
                            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
                            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
                        ),
                        )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
    parser.add_argument("--train_sp_batch_size", type=int, default=1, help="Batch size for sequence parallel training")

    # inpaint
    parser.add_argument("--mask_config", type=str, default=None)
    parser.add_argument("--add_noise_to_condition", action='store_true')
    parser.add_argument("--default_text_ratio", type=float, default=0.5)

    # for inpainting mode
    args = parser.parse_args()

    assert args.mask_config is not None, 'mask_config is required!'
    with open(args.mask_config, 'r') as f:
        # An empty YAML document loads as None; fall back to an empty mapping.
        yaml_config = yaml.safe_load(f) or {}

    # Values from the mask config only fill in attributes the CLI did not define.
    for key, value in yaml_config.items():
        if not hasattr(args, key):
            setattr(args, key, value)

    main(args)
""" import argparse import logging import math import os import shutil from pathlib import Path from typing import Optional import gc import numpy as np from einops import rearrange import torch.utils import torch.utils.data from tqdm import tqdm import time from opensora.adaptor.modules import replace_with_fp32_forwards try: import torch_npu from opensora.npu_config import npu_config from opensora.acceleration.parallel_states import initialize_sequence_parallel_state, \ destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state from opensora.acceleration.communications import prepare_parallel_data, broadcast except: torch_npu = None npu_config = None from opensora.utils.parallel_states import initialize_sequence_parallel_state, \ destroy_sequence_parallel_group, get_sequence_parallel_state, set_sequence_parallel_state from opensora.utils.communications import prepare_parallel_data, broadcast pass from dataclasses import field, dataclass from torch.utils.data import DataLoader from copy import deepcopy import accelerate import torch from torch.nn import functional as F import transformers from transformers.utils import ContextManagers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedType, ProjectConfiguration, set_seed from accelerate.state import AcceleratorState from packaging import version from tqdm.auto import tqdm import copy import diffusers from diffusers import DDPMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 from opensora.models.causalvideovae import ae_stride_config, ae_channel_config from 
@torch.inference_mode()
def log_validation(args, model, vae, text_encoder, tokenizer, accelerator, weight_dtype, global_step, ema=False):
    """Sample the fixed validation prompts and push the results to the active trackers.

    An ``OpenSoraPipeline`` is built around the unwrapped ``model`` and called once per
    prompt. The stacked results are written to every ``tensorboard`` / ``wandb`` tracker
    of ``accelerator`` under the tag ``validation`` (``ema_validation`` when ``ema`` is set).

    Args:
        args: Parsed arguments; the sampling settings (frames, size, steps, guidance
            scale, max sequence length) are read from it.
        model: Transformer to sample with; unwrapped through ``accelerator.unwrap_model``.
        vae: Autoencoder handed to the pipeline.
        text_encoder: Indexable pair; items 0 and 1 become ``text_encoder_1`` / ``text_encoder_2``.
        tokenizer: Tokenizer handed to the pipeline.
        accelerator: Provides the device, ``unwrap_model`` and the trackers.
        weight_dtype: Not used by this function.
        global_step: Step under which the media are logged.
        ema: Whether the samples come from EMA weights (only changes the tag).
    """
    positive_prompt = "(masterpiece), (best quality), (ultra-detailed), {}. emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous"
    negative_prompt = """nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, """
    validation_prompt = [
        "a cat wearing sunglasses and working as a lifeguard at pool.",
        "A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene."
    ]
    logger.info("Running validation....\n")

    transformer = accelerator.unwrap_model(model)
    opensora_pipeline = OpenSoraPipeline(
        vae=vae,
        text_encoder_1=text_encoder[0],
        text_encoder_2=text_encoder[1],
        tokenizer=tokenizer,
        scheduler=DPMSolverMultistepScheduler(),
        transformer=transformer,
    ).to(device=accelerator.device)

    # Sampling settings shared by every prompt.
    sampling_kwargs = dict(
        negative_prompt=negative_prompt,
        num_frames=args.num_frames,
        height=args.max_height,
        width=args.max_width,
        num_inference_steps=args.num_sampling_steps,
        guidance_scale=args.guidance_scale,
        enable_temporal_attentions=True,
        num_images_per_prompt=1,
        mask_feature=True,
        max_sequence_length=args.model_max_length,
    )
    samples = []
    for prompt in validation_prompt:
        logger.info('Processing the ({}) prompt'.format(prompt))
        output = opensora_pipeline(positive_prompt.format(prompt), **sampling_kwargs)
        samples.append(output.images[0])
    # import ipdb;ipdb.set_trace()
    gc.collect()
    torch.cuda.empty_cache()

    videos = rearrange(torch.stack(samples).numpy(), 'b t h w c -> b t c h w')
    tag = f"{'ema_' if ema else ''}validation"
    single_frame = videos.shape[1] == 1

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            if single_frame:
                assert args.num_frames == 1
                frames = rearrange(videos, 'b 1 c h w -> (b 1) h w c')
                tracker.writer.add_images(
                    tag, np.stack([np.asarray(frame) for frame in frames]), global_step, dataformats="NHWC"
                )
            else:
                tracker.writer.add_video(
                    tag, np.stack([np.asarray(clip) for clip in videos]), global_step, fps=24
                )
        if tracker.name == "wandb":
            import wandb
            if single_frame:
                frames = rearrange(videos, 'b 1 c h w -> (b 1) h w c')
                media = [
                    wandb.Image(frame, caption=f"{i}: {prompt}")
                    for i, (frame, prompt) in enumerate(zip(frames, validation_prompt))
                ]
            else:
                media = [
                    wandb.Video(clip, caption=f"{i}: {prompt}", fps=24)
                    for i, (clip, prompt) in enumerate(zip(videos, validation_prompt))
                ]
            tracker.log({tag: media}, step=global_step)

    # Drop the pipeline and release cached memory before training resumes.
    del opensora_pipeline
    gc.collect()
    torch.cuda.empty_cache()


class ProgressInfo:
    """Mutable holder for the training counters shared between the training closures."""

    def __init__(self, global_step, train_loss=0.0):
        # Number of optimizer updates counted so far.
        self.global_step = global_step
        # Loss accumulated since it was last reported.
        self.train_loss = train_loss
if args.seed is not None: set_seed(args.seed, device_specific=True) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) # For mixed precision training we cast all non-trainable weigths to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Create model: # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. # For this to work properly all models must be run through `accelerate.prepare`. But accelerate # will try to assign the same optimizer with the same weights to all models during # `deepspeed.initialize`, which of course doesn't work. # # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 # frozen models from being partitioned during `zero.Init` which gets called during # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
# NOTE(review): nesting below was reconstructed from a whitespace-collapsed dump;
# confirm against the original file that the text encoders are built inside the `with`.
def deepspeed_zero_init_disabled_context_manager():
    """
    Build the context list used while constructing the frozen models.

    Returns a one-element list holding a context that disables zero.Init when a
    DeepSpeed plugin is active, and an empty list otherwise.
    """
    if not accelerate.state.is_initialized():
        return []
    plugin = AcceleratorState().deepspeed_plugin
    if plugin is None:
        return []
    return [plugin.zero3_init_context_manager(enable=False)]

with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    # Autoencoder, put in eval mode; tiling is optional.
    ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir).eval()
    if args.enable_tiling:
        ae.vae.enable_tiling()

    # Text encoder(s); the second one is optional.
    text_enc_kwargs = dict(torch_dtype=weight_dtype, low_cpu_mem_usage=False)
    text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args, **text_enc_kwargs).eval()
    if args.text_encoder_name_2 is None:
        text_enc_2 = None
    else:
        text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args, **text_enc_kwargs).eval()

# Autoencoder strides (t, h, w), mirrored onto the autoencoder and onto `args`.
ae_strides = ae_stride_config[args.ae]
ae_stride_t, ae_stride_h, ae_stride_w = ae_strides
ae.vae_scale_factor = (ae_stride_t, ae_stride_h, ae_stride_w)
assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})"
args.ae_stride_t = ae_stride_t
args.ae_stride_h = ae_stride_h
args.ae_stride_w = ae_stride_w
args.ae_stride = ae_stride_h

# The last three characters of the model name are read as the (t, h, w) patch sizes.
patch_digits = args.model[-3:]
patch_size_t = int(patch_digits[0])
patch_size_h = int(patch_digits[1])
patch_size_w = int(patch_digits[2])
args.patch_size = patch_size_h
args.patch_size_t = patch_size_t
args.patch_size_h = patch_size_h
args.patch_size_w = patch_size_w
assert patch_size_h == patch_size_w, f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})"
assert (args.num_frames - 1) % ae_stride_t == 0, f"(Frames - 1) must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})."
# Spatial sizes must be whole multiples of the autoencoder strides. The latent
# grid below divides height by ae_stride_h and width by ae_stride_w, so each
# dimension is validated against its own stride.
assert args.max_height % ae_stride_h == 0, f"Height must be divisible by ae_stride_h, but found Height ({args.max_height}), ae_stride_h ({ae_stride_h})."
assert args.max_width % ae_stride_w == 0, f"Width size must be divisible by ae_stride_w, but found Width ({args.max_width}), ae_stride_w ({ae_stride_w})."

# Combined stride of autoencoder and patch embedding, stored on `args`.
args.stride_t = ae_stride_t * patch_size_t
args.stride = ae_stride_h * patch_size_h

# Latent grid: (height, width) tuple and number of latent frames.
ae.latent_size = latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w)
args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1

# NOTE(review): the same (h, w) tuple is passed as both sample_size_h and
# sample_size_w; confirm that the model constructor expects this.
model = Diffusion_models[args.model](
    in_channels=ae_channel_config[args.ae],
    out_channels=ae_channel_config[args.ae],
    sample_size_h=latent_size,
    sample_size_w=latent_size,
    sample_size_t=latent_size_t,
    interpolation_scale_h=args.interpolation_scale_h,
    interpolation_scale_w=args.interpolation_scale_w,
    interpolation_scale_t=args.interpolation_scale_t,
    sparse1d=args.sparse1d,
    sparse_n=args.sparse_n,
    skip_connection=args.skip_connection,
)

# # use pretrained model?
# Optionally initialize the diffusion model from `args.pretrained`, which may be
#   * a `.safetensors` file,
#   * a directory (loaded with `from_pretrained`), or
#   * any other file readable by `torch.load`, optionally wrapped as {'model': state_dict}.
# For the two file formats only keys present in both state dicts with the same
# number of elements are loaded; everything else keeps its fresh initialization.
# NOTE(review): nesting was reconstructed from a whitespace-collapsed dump.
if args.pretrained:
    model_state_dict = model.state_dict()
    print(f'Load from {args.pretrained}')
    pretrained_checkpoint = None
    if args.pretrained.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        pretrained_checkpoint = safe_load(args.pretrained, device="cpu")
    elif os.path.isdir(args.pretrained):
        # The freshly built model is replaced by the one loaded from the directory.
        model = Diffusion_models_class[args.model].from_pretrained(args.pretrained)
        missing_keys, unexpected_keys = [], []
    else:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        pretrained_checkpoint = torch.load(args.pretrained, map_location='cpu')
        # Fix: the original tested `'model' in checkpoint`, but `checkpoint` is not
        # assigned yet on this path, so this branch always raised NameError.
        if 'model' in pretrained_checkpoint:
            pretrained_checkpoint = pretrained_checkpoint['model']

    if pretrained_checkpoint is not None:
        # Keep only tensors whose key exists in the model and whose element count matches.
        checkpoint = {
            k: v for k, v in pretrained_checkpoint.items()
            if k in model_state_dict and model_state_dict[k].numel() == v.numel()
        }
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)

    print(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}')
    print(f'Successfully load {len(model_state_dict) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!')

model.gradient_checkpointing = args.gradient_checkpointing

# Freeze vae and text encoders.
ae.vae.requires_grad_(False)
text_enc_1.requires_grad_(False)
if text_enc_2 is not None:
    text_enc_2.requires_grad_(False)
# Set model as trainable.
# NOTE(review): statement nesting in this span was reconstructed from a
# whitespace-collapsed dump; confirm against the original file.
model.train()

# Noise scheduler selection. The first matching flag wins; with no flag set a
# DDPMScheduler with the shared kwargs is used.
kwargs = dict(
    prediction_type=args.prediction_type,
    rescale_betas_zero_snr=args.rescale_betas_zero_snr
)
if args.cogvideox_scheduler:
    noise_scheduler = CogVideoXDDIMScheduler(**kwargs)
elif args.v1_5_scheduler:
    kwargs['beta_start'] = 0.00085
    kwargs['beta_end'] = 0.0120
    kwargs['beta_schedule'] = "scaled_linear"
    noise_scheduler = DDPMScheduler(**kwargs)
elif args.rf_scheduler:
    # The flow-matching scheduler is built without the shared kwargs.
    # `noise_scheduler_copy` is defined only on this branch.
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
else:
    noise_scheduler = DDPMScheduler(**kwargs)

# Move the autoencoder and text encoder(s) to the accelerator device and cast them.
# The VAE is kept in float32 only when `args.vae_fp32` is set; otherwise it uses weight_dtype.
# With `args.extra_save_mem` nothing is moved here.
if not args.extra_save_mem:
    ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype)
    text_enc_1.to(accelerator.device, dtype=weight_dtype)
    if text_enc_2 is not None:
        text_enc_2.to(accelerator.device, dtype=weight_dtype)

# Create the EMA wrapper from a deep copy of the model.
if args.use_ema:
    ema_model = deepcopy(model)
    ema_model = EMAModel(ema_model.parameters(), decay=args.ema_decay, update_after_step=args.ema_start_step,
                         model_cls=Diffusion_models_class[args.model], model_config=ema_model.config,
                         foreach=args.foreach_ema)

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        """Save hook registered on the accelerator below.

        On the main process: writes the EMA weights to `<output_dir>/model_ema`
        when EMA is enabled, and each entry of `models` to `<output_dir>/model`.
        """
        if accelerator.is_main_process:
            if args.use_ema:
                ema_model.save_pretrained(os.path.join(output_dir, "model_ema"))

            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "model"))
                if weights:  # Don't pop if empty
                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

    def load_model_hook(models, input_dir):
        """Load hook registered on the accelerator below.

        Restores the EMA state from `<input_dir>/model_ema` when EMA is enabled,
        then reloads config and weights of every entry of `models` from
        `<input_dir>/model`.
        """
        if args.use_ema:
            load_model = EMAModel.from_pretrained(
                os.path.join(input_dir, "model_ema"),
                Diffusion_models_class[args.model],
                foreach=args.foreach_ema,
            )
            ema_model.load_state_dict(load_model.state_dict())
            # Pin the EMA weights when offloading is requested, otherwise place them on the device.
            if args.offload_ema:
                ema_model.pin_memory()
            else:
                ema_model.to(accelerator.device)
            del load_model

        for i in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            load_model = Diffusion_models_class[args.model].from_pretrained(input_dir, subfolder="model")
            model.register_to_config(**load_model.config)

            model.load_state_dict(load_model.state_dict())
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

params_to_optimize = model.parameters()
# Optimizer creation
# Anything other than "prodigy"/"adamw" (case-insensitive) falls back to AdamW with a warning.
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
    logger.warning(
        f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
        "Defaulting to adamW"
    )
    args.optimizer = "adamw"

if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
    logger.warning(
        f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
        f"set to {args.optimizer.lower()}"
    )

if args.optimizer.lower() == "adamw":
    if args.use_8bit_adam:
        # bitsandbytes is imported lazily, only when 8-bit Adam is requested.
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

if args.optimizer.lower() == "prodigy":
    # prodigyopt is imported lazily, only when Prodigy is requested.
    try:
        import prodigyopt
    except ImportError:
        raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

    optimizer_class = prodigyopt.Prodigy

    if args.learning_rate <= 0.1:
        logger.warning(
            "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
        )

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        beta3=args.prodigy_beta3,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=args.prodigy_decouple,
        use_bias_correction=args.prodigy_use_bias_correction,
        safeguard_warmup=args.prodigy_safeguard_warmup,
    )
logger.info(f"optimizer: {optimizer}")

# Setup data:
# The sampler starts from `args.trained_data_global_step` when given, else from step 0.
if args.trained_data_global_step is not None:
    initial_global_step_for_sampler = args.trained_data_global_step
else:
    initial_global_step_for_sampler = 0
# Effective batch size: per-device batch x processes x accumulation steps,
# then rescaled by train_sp_batch_size / sp_size.
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size
args.total_batch_size = total_batch_size
# min_hxw defaults to a quarter of max_hxw when only max_hxw is given.
if args.max_hxw is not None and args.min_hxw is None:
    args.min_hxw = args.max_hxw // 4

train_dataset = getdataset(args)
sampler = LengthGroupedSampler(
    args.train_batch_size,
    world_size=accelerator.num_processes,
    gradient_accumulation_size=args.gradient_accumulation_steps,
    initial_global_step=initial_global_step_for_sampler,
    lengths=train_dataset.lengths,
    group_data=args.group_data,
)
# shuffle=False: iteration order is delegated to the custom sampler.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,
    pin_memory=True,
    collate_fn=Collate(args),
    batch_size=args.train_batch_size,
    num_workers=args.dataloader_num_workers,
    sampler=sampler,
    drop_last=True,
    # prefetch_factor=4
)
logger.info(f'after train_dataloader')

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

# Both step counts handed to the LR scheduler are scaled by gradient_accumulation_steps.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)

# Prepare everything with our `accelerator`.
# model.requires_grad_(False)
# model.pos_embed.requires_grad_(True)
# model.patch_embed.requires_grad_(True)
logger.info(f'before accelerator.prepare')
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
logger.info(f'after accelerator.prepare')

# Place the EMA weights: pinned memory when offloading, otherwise on the accelerator device.
if args.use_ema:
    if args.offload_ema:
        ema_model.pin_memory()
    else:
        ema_model.to(accelerator.device)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# init_trackers is called on the main process only, with the output directory's base name as project name.
if accelerator.is_main_process:
    accelerator.init_trackers(os.path.basename(args.output_dir), config=vars(args))

# Train!
print(f" Args = {args}") print(f" noise_scheduler = {noise_scheduler}") logger.info("***** Running training *****") logger.info(f" Model = {model}") logger.info(f" Args = {args}") logger.info(f" Noise_scheduler = {noise_scheduler}") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" Total optimization steps (num_update_steps_per_epoch) = {num_update_steps_per_epoch}") logger.info(f" Total training parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B") logger.info(f" AutoEncoder = {args.ae}; Dtype = {ae.vae.dtype}; Parameters = {sum(p.numel() for p in ae.parameters()) / 1e9} B") logger.info(f" Text_enc_1 = {args.text_encoder_name_1}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_1.parameters()) / 1e9} B") if args.text_encoder_name_2 is not None: logger.info(f" Text_enc_2 = {args.text_encoder_name_2}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_2.parameters()) / 1e9} B") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_info = ProgressInfo(global_step, train_loss=0.0) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma def sync_gradients_info(loss): # Checks if the accelerator has performed an optimization step behind the scenes if args.use_ema: if args.offload_ema: ema_model.to(device="cuda", non_blocking=True) ema_model.step(model.parameters()) if args.offload_ema: ema_model.to(device="cpu", non_blocking=True) progress_bar.update(1) progress_info.global_step += 1 end_time = time.time() one_step_duration = end_time - start_time if progress_info.global_step % args.log_interval == 0: train_loss = progress_info.train_loss.item() / args.log_interval accelerator.log({"train_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]}, step=progress_info.global_step) if torch_npu is not None and npu_config is not None: npu_config.print_msg(f"Step: [{progress_info.global_step}], local_loss={loss.detach().item()}, " f"train_loss={train_loss}, time_cost={one_step_duration}", rank=0) progress_info.train_loss = torch.tensor(0.0, device=loss.device) # DeepSpeed requires 
saving weights on every device; saving weights only on the main process would cause issues. if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if progress_info.global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if accelerator.is_main_process and args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) def run(step_, model_input, model_kwargs, prof): # print("rank {} | step {} | cd run fun".format(accelerator.process_index, step_)) global start_time start_time = time.time() noise = torch.randn_like(model_input) bsz = model_input.shape[0] if not args.rf_scheduler: if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((model_input.shape[0], model_input.shape[1], 1, 1, 1), device=model_input.device) # Sample a random timestep for 
each image without bias. timesteps = explicit_uniform_sampling( T=noise_scheduler.config.num_train_timesteps, n=accelerator.num_processes, rank=accelerator.process_index, bsz=bsz, device=model_input.device, ) if get_sequence_parallel_state(): # image do not need sp, disable when image batch broadcast(timesteps) # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly u = compute_density_for_timestep_sampling( weighting_scheme=args.weighting_scheme, batch_size=bsz, logit_mean=args.logit_mean, logit_std=args.logit_std, mode_scale=args.mode_scale, ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise model_pred = model( noisy_model_input, timesteps, **model_kwargs )[0] mask = model_kwargs.get('attention_mask', None) if not args.rf_scheduler: # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(model_input, noise, timesteps) elif noise_scheduler.config.prediction_type == "sample": # We set the target to latents here, but the model_pred will return the noise sample prediction. target = model_input # We will have to subtract the noise residual from the prediction to get the target sample. 
model_pred = model_pred - noise else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if get_sequence_parallel_state(): if torch.all(mask.bool()): mask = None # mask (sp_bs*b t h w) assert mask is None b, c, _, _, _ = model_pred.shape if mask is not None: mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float() # b t h w -> b c t h w mask = mask.reshape(b, -1) if args.snr_gamma is None: # model_pred: b c t h w, attention_mask: b t h w loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.reshape(b, -1) if mask is not None: loss = (loss * mask).sum() / mask.sum() # mean loss on unpad patches else: loss = loss.mean() else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( dim=1 )[0] if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = mse_loss_weights / (snr + 1) else: raise NameError(f'{noise_scheduler.config.prediction_type}') loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.reshape(b, -1) mse_loss_weights = mse_loss_weights.reshape(b, 1) if mask is not None: loss = (loss * mask * mse_loss_weights).sum() / mask.sum() # mean loss on unpad patches else: loss = (loss * mse_loss_weights).mean() else: if torch.all(mask.bool()): mask = None b, c, _, _, _ = model_pred.shape if mask is not None: mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float() # b t h w -> b c t h w mask = mask.reshape(b, -1) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = 
compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = noise - model_input # Compute regular loss. loss_mse = (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1) if mask is not None: loss = (loss_mse * mask).sum() / mask.sum() else: loss = loss_mse.mean() # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() # avg_loss = accelerator.reduce(loss, reduction="mean") # progress_info.train_loss += avg_loss.detach().item() / args.gradient_accumulation_steps progress_info.train_loss += avg_loss.detach() / args.gradient_accumulation_steps # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = model.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() if accelerator.sync_gradients: sync_gradients_info(loss) if accelerator.is_main_process: if progress_info.global_step % args.checkpointing_steps == 0: if args.enable_tracker: log_validation( args, model, ae, [text_enc_1.text_enc, getattr(text_enc_2, 'text_enc', None)], train_dataset.tokenizer, accelerator, weight_dtype, progress_info.global_step ) if args.use_ema and npu_config is None: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_model.store(model.parameters()) ema_model.copy_to(model.parameters()) log_validation( args, model, ae, [text_enc_1.text_enc, getattr(text_enc_2, 'text_enc', None)], train_dataset.tokenizer, accelerator, weight_dtype, progress_info.global_step, ema=True ) # Switch back to the original UNet parameters. 
ema_model.restore(model.parameters()) if prof is not None: prof.step() return loss def train_one_step(step_, data_item_, prof_=None): train_loss = 0.0 # print("rank {} | step {} | unzip data".format(accelerator.process_index, step_)) x, attn_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = data_item_ # print(f'step: {step_}, rank: {accelerator.process_index}, x: {x.shape}, dtype: {x.dtype}') # assert not torch.any(torch.isnan(x)), 'torch.any(torch.isnan(x))' if args.extra_save_mem: torch.cuda.empty_cache() ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype) text_enc_1.to(accelerator.device, dtype=weight_dtype) if text_enc_2 is not None: text_enc_2.to(accelerator.device, dtype=weight_dtype) x = x.to(accelerator.device, dtype=ae.vae.dtype) # B C T H W # x = x.to(accelerator.device, dtype=torch.float32) # B C T H W attn_mask = attn_mask.to(accelerator.device) # B T H W input_ids_1 = input_ids_1.to(accelerator.device) # B 1 L cond_mask_1 = cond_mask_1.to(accelerator.device) # B 1 L input_ids_2 = input_ids_2.to(accelerator.device) if input_ids_2 is not None else input_ids_2 # B 1 L cond_mask_2 = cond_mask_2.to(accelerator.device) if cond_mask_2 is not None else cond_mask_2 # B 1 L with torch.no_grad(): B, N, L = input_ids_1.shape # B 1 L # use batch inference input_ids_1 = input_ids_1.reshape(-1, L) cond_mask_1 = cond_mask_1.reshape(-1, L) cond_1 = text_enc_1(input_ids_1, cond_mask_1) # B L D cond_1 = cond_1.reshape(B, N, L, -1) cond_mask_1 = cond_mask_1.reshape(B, N, L) if text_enc_2 is not None: B_, N_, L_ = input_ids_2.shape # B 1 L input_ids_2 = input_ids_2.reshape(-1, L_) cond_2 = text_enc_2(input_ids_2, cond_mask_2) # B D cond_2 = cond_2.reshape(B_, 1, -1) # B 1 D else: cond_2 = None # Map input images to latent space + normalize latents x = ae.encode(x) # B C T H W # print(f'step: {step_}, rank: {accelerator.process_index}, after vae.encode, x: {x.shape}, dtype: {x.dtype}, mean: {x.mean()}, std: {x.std()}') # x = 
torch.rand(1, 32, 14, 80, 80).to(x.device, dtype=x.dtype) # def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None: # from examples.rec_video import array_to_video # x = x.detach().cpu() # x = torch.clamp(x, -1, 1) # x = (x + 1) / 2 # x = x.permute(1, 2, 3, 0).numpy() # x = (255*x).astype(np.uint8) # array_to_video(x, fps=fps, output_file=output_file) # return # videos = ae.decode(x)[0] # videos = videos.transpose(0, 1) # custom_to_video(videos.to(torch.float32), fps=24, output_file='tmp.mp4') # import sys;sys.exit() # print("rank {} | step {} | after encode".format(accelerator.process_index, step_)) if args.extra_save_mem: ae.vae.to('cpu') text_enc_1.to('cpu') if text_enc_2 is not None: text_enc_2.to('cpu') torch.cuda.empty_cache() current_step_frame = x.shape[2] current_step_sp_state = get_sequence_parallel_state() if args.sp_size != 1: # enable sp if current_step_frame == 1: # but image do not need sp set_sequence_parallel_state(False) else: set_sequence_parallel_state(True) if get_sequence_parallel_state(): x, cond_1, attn_mask, cond_mask_1, cond_2 = prepare_parallel_data( x, cond_1, attn_mask, cond_mask_1, cond_2 ) # x (b c t h w) -gather0-> (sp*b c t h w) -scatter2-> (sp*b c t//sp h w) # cond_1 (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d) # attn_mask (b t*sp h w) -gather0-> (sp*b t*sp h w) -scatter1-> (sp*b t h w) # cond_mask_1 (b sp l) -gather0-> (sp*b sp l) -scatter1-> (sp*b 1 l) # cond_2 (b sp d) -gather0-> (sp*b sp d) -scatter1-> (sp*b 1 d) for iter in range(args.train_batch_size * args.sp_size // args.train_sp_batch_size): with accelerator.accumulate(model): # x (sp_bs*b c t//sp h w) # cond_1 (sp_bs*b 1 l/sp d) # attn_mask (sp_bs*b t h w) # cond_mask_1 (sp_bs*b 1 l) # cond_2 (sp_bs*b 1 d) st_idx = iter * args.train_sp_batch_size ed_idx = (iter + 1) * args.train_sp_batch_size model_kwargs = dict( encoder_hidden_states=cond_1[st_idx: ed_idx], attention_mask=attn_mask[st_idx: ed_idx], 
encoder_attention_mask=cond_mask_1[st_idx: ed_idx], pooled_projections=cond_2[st_idx: ed_idx] if cond_2 is not None else None, ) run(step_, x[st_idx: ed_idx], model_kwargs, prof_) else: with accelerator.accumulate(model): # assert not torch.any(torch.isnan(x)), 'after vae' x = x.to(weight_dtype) model_kwargs = dict( encoder_hidden_states=cond_1, attention_mask=attn_mask, encoder_attention_mask=cond_mask_1, pooled_projections=cond_2 ) run(step_, x, model_kwargs, prof_) set_sequence_parallel_state(current_step_sp_state) # in case the next step use sp, which need broadcast(timesteps) if progress_info.global_step >= args.max_train_steps: return True return False def train_one_epoch(prof_=None): # for epoch in range(first_epoch, args.num_train_epochs): progress_info.train_loss = 0.0 if progress_info.global_step >= args.max_train_steps: return True for step, data_item in enumerate(train_dataloader): # print("rank {} | step {} | get data".format(accelerator.process_index, step)) if train_one_step(step, data_item, prof_): break if step >= 2 and torch_npu is not None and npu_config is not None: npu_config.free_mm() if npu_config is not None and npu_config.on_npu and npu_config.profiling: experimental_config = torch_npu.profiler._ExperimentalConfig( profiler_level=torch_npu.profiler.ProfilerLevel.Level1, aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization ) profile_output_path = f"/home/image_data/npu_profiling_t2v/{os.getenv('PROJECT_NAME', 'local')}" os.makedirs(profile_output_path, exist_ok=True) with torch_npu.profiler.profile( activities=[ torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU, ], with_stack=True, record_shapes=True, profile_memory=True, experimental_config=experimental_config, schedule=torch_npu.profiler.schedule( wait=npu_config.profiling_step, warmup=0, active=1, repeat=1, skip_first=0 ), on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f"{profile_output_path}/") ) as prof: train_one_epoch(prof) else: if 
args.enable_profiling: with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], schedule=torch.profiler.schedule(wait=5, warmup=1, active=1, repeat=1, skip_first=0), on_trace_ready=torch.profiler.tensorboard_trace_handler('./gpu_profiling_active_1_delmask_delbkmask_andvaemask_curope_gpu'), record_shapes=True, profile_memory=True, with_stack=True ) as prof: train_one_epoch(prof) else: train_one_epoch() accelerator.wait_for_everyone() accelerator.end_training() if get_sequence_parallel_state(): destroy_sequence_parallel_group() if __name__ == "__main__": parser = argparse.ArgumentParser() # dataset & dataloader parser.add_argument("--dataset", type=str, required=True) parser.add_argument("--data", type=str, required='') parser.add_argument("--sample_rate", type=int, default=1) parser.add_argument("--train_fps", type=int, default=24) parser.add_argument("--drop_short_ratio", type=float, default=1.0) parser.add_argument("--speed_factor", type=float, default=1.0) parser.add_argument("--num_frames", type=int, default=65) parser.add_argument("--max_height", type=int, default=320) parser.add_argument("--max_width", type=int, default=240) parser.add_argument("--max_hxw", type=int, default=None) parser.add_argument("--min_hxw", type=int, default=None) parser.add_argument("--ood_img_ratio", type=float, default=0.0) parser.add_argument("--use_img_from_vid", action="store_true") parser.add_argument("--model_max_length", type=int, default=512) parser.add_argument('--cfg', type=float, default=0.1) parser.add_argument("--dataloader_num_workers", type=int, default=10, help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.") parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader.") parser.add_argument("--group_data", action="store_true") parser.add_argument("--hw_stride", type=int, default=32) parser.add_argument("--force_resolution", action="store_true") parser.add_argument("--trained_data_global_step", type=int, default=None) parser.add_argument("--use_decord", action="store_true") # text encoder & vae & diffusion model parser.add_argument('--vae_fp32', action='store_true') parser.add_argument('--extra_save_mem', action='store_true') parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="Latte-XL/122") parser.add_argument('--enable_tiling', action='store_true') parser.add_argument('--interpolation_scale_h', type=float, default=1.0) parser.add_argument('--interpolation_scale_w', type=float, default=1.0) parser.add_argument('--interpolation_scale_t', type=float, default=1.0) parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse") parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl') parser.add_argument("--text_encoder_name_2", type=str, default=None) parser.add_argument("--cache_dir", type=str, default='./cache_dir') parser.add_argument("--pretrained", type=str, default=None) parser.add_argument('--sparse1d', action='store_true') parser.add_argument('--sparse_n', type=int, default=2) parser.add_argument('--skip_connection', action='store_true') parser.add_argument('--cogvideox_scheduler', action='store_true') parser.add_argument('--v1_5_scheduler', action='store_true') parser.add_argument('--rf_scheduler', action='store_true') parser.add_argument("--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]) parser.add_argument("--logit_mean", 
type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme.") parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") parser.add_argument("--mode_scale", type=float, default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.") parser.add_argument("--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.") # diffusion setting parser.add_argument("--snr_gamma", type=float, default=None, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556.") parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument("--ema_decay", type=float, default=0.9999) parser.add_argument("--ema_start_step", type=int, default=0) parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.") parser.add_argument("--noise_offset", type=float, default=0.0, help="The scale of noise offset.") parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.") parser.add_argument('--rescale_betas_zero_snr', action='store_true') # validation & logs parser.add_argument("--log_interval", type=int, default=10) parser.add_argument("--enable_profiling", action="store_true") parser.add_argument("--num_sampling_steps", type=int, default=20) parser.add_argument('--guidance_scale', type=float, default=4.5) parser.add_argument("--enable_tracker", action="store_true") parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument("--output_dir", type=str, default=None, help="The output directory where the model predictions and checkpoints will be written.") parser.add_argument("--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store.")) parser.add_argument("--checkpointing_steps", type=int, default=500, help=( "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" " training using `--resume_from_checkpoint`." ), ) parser.add_argument("--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument("--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument("--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
), ) # optimizer & scheduler parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument("--optimizer", type=str, default="adamW", help='The optimizer type to use. Choose between ["AdamW", "prodigy"]') parser.add_argument("--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.") parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers.") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use for unet params") parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder") parser.add_argument("--adam_epsilon", type=float, default=1e-15, help="Epsilon value for the Adam optimizer and Prodigy optimizers.") parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True, help="Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW") parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--prodigy_beta3", type=float, default=None, help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument("--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") parser.add_argument("--train_sp_batch_size", type=int, default=1, help="Batch size for sequence parallel training") args = parser.parse_args() main(args) ================================================ FILE: opensora/utils/communications.py ================================================ import torch import torch.distributed as dist from einops import rearrange from opensora.utils.parallel_states import nccl_info def broadcast(input_: torch.Tensor): sp_size = nccl_info.world_size src = nccl_info.rank // sp_size * sp_size dist.broadcast(input_, src=src, group=nccl_info.group) _COUNT = 0 def _all_to_all( input_: torch.Tensor, scatter_dim: int, gather_dim: int, ): group = nccl_info.group sp_size = nccl_info.world_size input_list = [t.contiguous() for t in torch.tensor_split(input_, sp_size, scatter_dim)] output_list = [torch.empty_like(input_list[0]) for _ in range(sp_size)] dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_dim).contiguous() def _single_all_to_all( input_: torch.Tensor, scatter_dim: int, gather_dim: int, enable_HCCL=False, ): sp_size = nccl_info.world_size inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size if scatter_dim < 1: input_t = input_.reshape( [sp_size, inp_shape[scatter_dim]] + \ inp_shape[scatter_dim + 1:] ) else: # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
input_t = input_.reshape( [-1, sp_size, inp_shape[scatter_dim]] + \ inp_shape[scatter_dim + 1:] ).transpose(0, 1).contiguous() output = torch.empty_like(input_t) dist.all_to_all_single(output, input_t, group=nccl_info.group) # if scattering the seq-dim, transpose the heads back to the original dimension if scatter_dim < 1: output = output.transpose(0, 1).contiguous() return output.reshape( inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:]) class _AllToAll(torch.autograd.Function): """All-to-all communication. Args: input_: input matrix process_group: communication group scatter_dim: scatter dimension gather_dim: gather dimension """ @staticmethod def forward(ctx, input_, scatter_dim, gather_dim, all_to_all_func): ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim ctx.all_to_all = all_to_all_func output = ctx.all_to_all(input_, scatter_dim, gather_dim) return output @staticmethod def backward(ctx, grad_output): grad_output = ctx.all_to_all( grad_output, ctx.gather_dim, ctx.scatter_dim, ) return ( grad_output, None, None, None, ) def all_to_all_SBH( input_: torch.Tensor, scatter_dim: int = 1, gather_dim: int = 0, ): return _AllToAll.apply(input_, scatter_dim, gather_dim, _single_all_to_all) def all_to_all_BSND( input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1, ): return _AllToAll.apply(input_, scatter_dim, gather_dim, _all_to_all) def prepare_parallel_data( hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections, ): def all_to_all( hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections, ): # hidden_states (b c t h w) -gather0-> (sp*b c t h w) -scatter2-> (sp*b c t//sp h w) # encoder_hidden_states (b sp l/sp d) -gather0-> (sp*b sp l/sp d) -scatter1-> (sp*b 1 l/sp d) # attention_mask (b t*sp h w) -gather0-> (sp*b t*sp h w) -scatter1-> (sp*b t h w) # encoder_attention_mask (b sp l) -gather0-> (sp*b sp l) -scatter1-> (sp*b 1 l) 
# pooled_projections (b sp d) -gather0-> (sp*b sp d) -scatter1-> (sp*b 1 d) hidden_states = _single_all_to_all(hidden_states, scatter_dim=2, gather_dim=0, enable_HCCL=True) encoder_hidden_states = _single_all_to_all(encoder_hidden_states, scatter_dim=1, gather_dim=0, enable_HCCL=True) attention_mask = _single_all_to_all(attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True) encoder_attention_mask = _single_all_to_all(encoder_attention_mask, scatter_dim=1, gather_dim=0, enable_HCCL=True) if pooled_projections is not None: pooled_projections = _single_all_to_all(pooled_projections, scatter_dim=1, gather_dim=0, enable_HCCL=True) return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections sp_size = nccl_info.world_size frame = hidden_states.shape[2] assert frame % sp_size == 0, "frame should be a multiple of sp_size" encoder_hidden_states = rearrange( encoder_hidden_states, 'b 1 (n x) h -> b n x h', n=sp_size, x=encoder_hidden_states.shape[2]//sp_size ).contiguous() hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections = all_to_all( hidden_states, encoder_hidden_states, attention_mask.repeat(1, sp_size, 1, 1), encoder_attention_mask.repeat(1, sp_size, 1), pooled_projections.repeat(1, sp_size, 1) if pooled_projections is not None else None, ) return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask, pooled_projections ================================================ FILE: opensora/utils/dataset_utils.py ================================================ import math from einops import rearrange import decord from torch.nn import functional as F import torch from typing import Optional import torch.utils import torch.utils.data import torch from torch.utils.data import Sampler from typing import List from collections import Counter, defaultdict import random IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'] def is_image_file(filename): 
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) class DecordInit(object): """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" def __init__(self, num_threads=1): self.num_threads = num_threads self.ctx = decord.cpu(0) def __call__(self, filename): """Perform the Decord initialization. Args: results (dict): The resulting dict to be modified and passed to the next transform in pipeline. """ reader = decord.VideoReader(filename, ctx=self.ctx, num_threads=self.num_threads) return reader def __repr__(self): repr_str = (f'{self.__class__.__name__}(' f'sr={self.sr},' f'num_threads={self.num_threads})') return repr_str def pad_to_multiple(number, ds_stride): remainder = number % ds_stride if remainder == 0: return number else: padding = ds_stride - remainder return number + padding class Collate: def __init__(self, args): self.batch_size = args.train_batch_size self.group_data = args.group_data self.force_resolution = args.force_resolution self.max_height = args.max_height self.max_width = args.max_width self.ae_stride = args.ae_stride self.ae_stride_t = args.ae_stride_t self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride) self.patch_size = args.patch_size self.patch_size_t = args.patch_size_t self.num_frames = args.num_frames self.max_thw = (self.num_frames, self.max_height, self.max_width) def package(self, batch): batch_tubes = [i['pixel_values'] for i in batch] # b [c t h w] input_ids_1 = [i['input_ids_1'] for i in batch] # b [1 l] cond_mask_1 = [i['cond_mask_1'] for i in batch] # b [1 l] input_ids_2 = [i['input_ids_2'] for i in batch] # b [1 l] cond_mask_2 = [i['cond_mask_2'] for i in batch] # b [1 l] assert all([i is None for i in input_ids_2]) or all([i is not None for i in input_ids_2]) assert all([i is None for i in cond_mask_2]) or all([i is not None for i in cond_mask_2]) if all([i is None for i in input_ids_2]): input_ids_2 = None if all([i is None for i in cond_mask_2]): cond_mask_2 = 
None return batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 def __call__(self, batch): batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.package(batch) ds_stride = self.ae_stride * self.patch_size t_ds_stride = self.ae_stride_t * self.patch_size_t pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.process( batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2, t_ds_stride, ds_stride, self.max_thw, self.ae_stride_thw ) assert not torch.any(torch.isnan(pad_batch_tubes)), 'after pad_batch_tubes' return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 def process(self, batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2, t_ds_stride, ds_stride, max_thw, ae_stride_thw): # pad to max multiple of ds_stride batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] assert len(batch_input_size) == self.batch_size if self.group_data or self.batch_size == 1: # len_each_batch = batch_input_size idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)]) count_dict = Counter(len_each_batch) if len(count_dict) != 1: sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1]) # import ipdb;ipdb.set_trace() # print(batch, idx_length_dict, count_dict, sorted_by_value) pick_length = sorted_by_value[-1][0] # the highest frequency candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length] random_select_batch = [random.choice(candidate_batch) for _ in range(len(len_each_batch) - len(candidate_batch))] print(batch_input_size, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch) pick_idx = candidate_batch + random_select_batch batch_tubes = [batch_tubes[i] for i in pick_idx] batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] input_ids_1 = [input_ids_1[i] for i in pick_idx] # b [1, l] cond_mask_1 = [cond_mask_1[i] 
for i in pick_idx] # b [1, l] if input_ids_2 is not None: input_ids_2 = [input_ids_2[i] for i in pick_idx] # b [1, l] if cond_mask_2 is not None: cond_mask_2 = [cond_mask_2[i] for i in pick_idx] # b [1, l] for i in range(1, self.batch_size): assert batch_input_size[0] == batch_input_size[i] max_t = max([i[1] for i in batch_input_size]) max_h = max([i[2] for i in batch_input_size]) max_w = max([i[3] for i in batch_input_size]) else: max_t, max_h, max_w = max_thw pad_max_t, pad_max_h, pad_max_w = pad_to_multiple(max_t-1+self.ae_stride_t, t_ds_stride), \ pad_to_multiple(max_h, ds_stride), \ pad_to_multiple(max_w, ds_stride) pad_max_t = pad_max_t + 1 - self.ae_stride_t each_pad_t_h_w = [ [ pad_max_t - i.shape[1], pad_max_h - i.shape[2], pad_max_w - i.shape[3] ] for i in batch_tubes ] pad_batch_tubes = [ F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0) for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes) ] pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0) max_tube_size = [pad_max_t, pad_max_h, pad_max_w] max_latent_size = [ ((max_tube_size[0]-1) // ae_stride_thw[0] + 1), max_tube_size[1] // ae_stride_thw[1], max_tube_size[2] // ae_stride_thw[2] ] valid_latent_size = [ [ int(math.ceil((i[1]-1) / ae_stride_thw[0])) + 1, int(math.ceil(i[2] / ae_stride_thw[1])), int(math.ceil(i[3] / ae_stride_thw[2])) ] for i in batch_input_size] attention_mask = [ F.pad(torch.ones(i, dtype=pad_batch_tubes.dtype), (0, max_latent_size[2] - i[2], 0, max_latent_size[1] - i[1], 0, max_latent_size[0] - i[0]), value=0) for i in valid_latent_size] attention_mask = torch.stack(attention_mask) # b t h w if self.batch_size == 1 or self.group_data: if not torch.all(attention_mask.bool()): print(batch_input_size, (max_t, max_h, max_w), (pad_max_t, pad_max_h, pad_max_w), each_pad_t_h_w, max_latent_size, valid_latent_size) assert torch.all(attention_mask.bool()) input_ids_1 = torch.stack(input_ids_1) # b 1 l cond_mask_1 = torch.stack(cond_mask_1) # b 1 l input_ids_2 = 
torch.stack(input_ids_2) if input_ids_2 is not None else input_ids_2 # b 1 l cond_mask_2 = torch.stack(cond_mask_2) if cond_mask_2 is not None else cond_mask_2 # b 1 l return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 def group_data_fun(lengths, generator=None): # counter is decrease order counter = Counter(lengths) # counter {'1x256x256': 3, ''} lengths ['1x256x256', '1x256x256', '1x256x256', ...] grouped_indices = defaultdict(list) for idx, item in enumerate(lengths): # group idx to a list grouped_indices[item].append(idx) grouped_indices = dict(grouped_indices) # {'1x256x256': [0, 1, 2], ...} sorted_indices = [grouped_indices[item] for (item, _) in sorted(counter.items(), key=lambda x: x[1], reverse=True)] # shuffle in each group shuffle_sorted_indices = [] for indice in sorted_indices: shuffle_idx = torch.randperm(len(indice), generator=generator).tolist() shuffle_sorted_indices.extend([indice[idx] for idx in shuffle_idx]) return shuffle_sorted_indices def last_group_data_fun(shuffled_megabatches, lengths): # lengths ['1x256x256', '1x256x256', '1x256x256' ...] 
re_shuffled_megabatches = [] # print('shuffled_megabatches', len(shuffled_megabatches)) for i_megabatch, megabatch in enumerate(shuffled_megabatches): re_megabatch = [] for i_batch, batch in enumerate(megabatch): assert len(batch) != 0 len_each_batch = [lengths[i] for i in batch] # ['1x256x256', '1x256x256'] idx_length_dict = dict([*zip(batch, len_each_batch)]) # {0: '1x256x256', 100: '1x256x256'} count_dict = Counter(len_each_batch) # {'1x256x256': 2} or {'1x256x256': 1, '1x768x256': 1} if len(count_dict) != 1: sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1]) # {'1x256x256': 1, '1x768x256': 1} # import ipdb;ipdb.set_trace() # print(batch, idx_length_dict, count_dict, sorted_by_value) pick_length = sorted_by_value[-1][0] # the highest frequency candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length] random_select_batch = [random.choice(candidate_batch) for i in range(len(len_each_batch) - len(candidate_batch))] # print(batch, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch) batch = candidate_batch + random_select_batch # print(batch) for i in range(1, len(batch)-1): # if not lengths[batch[0]] == lengths[batch[i]]: # print(batch, [lengths[i] for i in batch]) # import ipdb;ipdb.set_trace() assert lengths[batch[0]] == lengths[batch[i]] re_megabatch.append(batch) re_shuffled_megabatches.append(re_megabatch) # for megabatch, re_megabatch in zip(shuffled_megabatches, re_shuffled_megabatches): # for batch, re_batch in zip(megabatch, re_megabatch): # for i, re_i in zip(batch, re_batch): # if i != re_i: # print(i, re_i) return re_shuffled_megabatches def split_to_even_chunks(megabatch, lengths, world_size, batch_size): """ Split a list of indices into `chunks` chunks of roughly equal lengths. 
""" # batch_size=2, world_size=2 # [1, 2, 3, 4] -> [[1, 2], [3, 4]] # [1, 2, 3] -> [[1, 2], [3]] # [1, 2] -> [[1], [2]] # [1] -> [[1], []] chunks = [megabatch[i::world_size] for i in range(world_size)] pad_chunks = [] for idx, chunk in enumerate(chunks): if batch_size != len(chunk): assert batch_size > len(chunk) if len(chunk) != 0: # [[1, 2], [3]] -> [[1, 2], [3, 3]] chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))] else: chunk = random.choice(pad_chunks) # [[1], []] -> [[1], [1]] print(chunks[idx], '->', chunk) pad_chunks.append(chunk) return pad_chunks def get_length_grouped_indices(lengths, batch_size, world_size, gradient_accumulation_size, initial_global_step, generator=None, group_data=False, seed=42): # We need to use torch for the random part as a distributed sampler will set the random seed for torch. if generator is None: generator = torch.Generator().manual_seed(seed) # every rank will generate a fixed order but random index # print('lengths', lengths) if group_data: indices = group_data_fun(lengths, generator) else: indices = torch.randperm(len(lengths), generator=generator).tolist() # print('indices', len(indices)) # print('sort indices', len(indices)) # print('sort indices', indices) # print('sort lengths', [lengths[i] for i in indices]) megabatch_size = world_size * batch_size megabatches = [indices[i: i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] # import ipdb;ipdb.set_trace() # print('megabatches', len(megabatches)) # print('\nmegabatches', megabatches) # megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] # import ipdb;ipdb.set_trace() # print('sort megabatches', len(megabatches)) megabatches_len = [[lengths[i] for i in megabatch] for megabatch in megabatches] # print(f'\nrank {accelerator.process_index} sorted megabatches_len', megabatches_len[0], megabatches_len[1], megabatches_len[-2], megabatches_len[-1]) # import ipdb;ipdb.set_trace() 
megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches] # import ipdb;ipdb.set_trace() # print('nsplit_to_even_chunks megabatches', len(megabatches)) # print('\nsplit_to_even_chunks megabatches', megabatches) split_to_even_chunks_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in megabatches] # print(f'\nrank {accelerator.process_index} split_to_even_chunks_len', split_to_even_chunks_len[0], split_to_even_chunks_len[1], split_to_even_chunks_len[-2], split_to_even_chunks_len[-1]) # print('\nsplit_to_even_chunks len', split_to_even_chunks_len) # return [i for megabatch in megabatches for batch in megabatch for i in batch] indices_mega = torch.randperm(len(megabatches), generator=generator).tolist() # print(f'rank {accelerator.process_index} seed {seed}, len(megabatches) {len(megabatches)}, indices_mega, {indices_mega[:50]}') shuffled_megabatches = [megabatches[i] for i in indices_mega] shuffled_megabatches_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches] # print(f'\nrank {accelerator.process_index} sorted shuffled_megabatches_len', shuffled_megabatches_len[0], shuffled_megabatches_len[1], shuffled_megabatches_len[-2], shuffled_megabatches_len[-1]) # import ipdb;ipdb.set_trace() # print('shuffled_megabatches', len(shuffled_megabatches)) if group_data: shuffled_megabatches = last_group_data_fun(shuffled_megabatches, lengths) group_shuffled_megabatches_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches] # print(f'\nrank {accelerator.process_index} group_shuffled_megabatches_len', group_shuffled_megabatches_len[0], group_shuffled_megabatches_len[1], group_shuffled_megabatches_len[-2], group_shuffled_megabatches_len[-1]) # import ipdb;ipdb.set_trace() initial_global_step = initial_global_step * gradient_accumulation_size # print('shuffled_megabatches', len(shuffled_megabatches)) # print('have been 
trained idx:', len(shuffled_megabatches[:initial_global_step])) # print('shuffled_megabatches[:10]', shuffled_megabatches[:10]) # print('have been trained idx:', shuffled_megabatches[:initial_global_step]) shuffled_megabatches = shuffled_megabatches[initial_global_step:] print(f'Skip the data of {initial_global_step} step!') # print('after shuffled_megabatches', len(shuffled_megabatches)) # print('after shuffled_megabatches[:10]', shuffled_megabatches[:10]) # print('\nshuffled_megabatches', shuffled_megabatches) # import ipdb;ipdb.set_trace() # print('\nshuffled_megabatches len', [[i, lengths[i]] for megabatch in shuffled_megabatches for batch in megabatch for i in batch]) return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch] class LengthGroupedSampler(Sampler): r""" Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while keeping a bit of randomness. """ def __init__( self, batch_size: int, world_size: int, gradient_accumulation_size: int, initial_global_step: int, lengths: Optional[List[int]] = None, group_data=False, generator=None, ): if lengths is None: raise ValueError("Lengths must be provided.") self.batch_size = batch_size self.world_size = world_size self.initial_global_step = initial_global_step self.gradient_accumulation_size = gradient_accumulation_size self.lengths = lengths self.group_data = group_data self.generator = generator # print('self.lengths, self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size', # len(self.lengths), self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size) def __len__(self): return len(self.lengths) - self.initial_global_step * self.batch_size * self.world_size * self.gradient_accumulation_size def __iter__(self): indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, self.gradient_accumulation_size, self.initial_global_step, 
group_data=self.group_data, generator=self.generator) # print(len(indices), indices[23640:23690]) # import sys;sys.exit() return iter(indices) ================================================ FILE: opensora/utils/downloader.py ================================================ import gdown import os opensora_cache_home = os.path.expanduser( os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) ) def gdown_download(id, fname, cache_dir=None): cache_dir = opensora_cache_home if not cache_dir else cache_dir os.makedirs(cache_dir, exist_ok=True) destination = os.path.join(cache_dir, fname) if os.path.exists(destination): return destination gdown.download(id=id, output=destination, quiet=False) return destination ================================================ FILE: opensora/utils/ema.py ================================================ import contextlib import copy import random from typing import Any, Dict, Iterable, List, Optional, Union from diffusers.utils import ( deprecate, is_torchvision_available, is_transformers_available, ) if is_transformers_available(): import transformers if is_torchvision_available(): from torchvision import transforms import numpy as np import torch # Adapted from diffusers-style ema https://github.com/huggingface/diffusers/blob/main/src/diffusers/training_utils.py#L263 class EMAModel: """ Exponential Moving Average of models weights """ def __init__( self, parameters: Iterable[torch.nn.Parameter], decay: float = 0.9999, min_decay: float = 0.0, update_after_step: int = 0, use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, model_cls: Optional[Any] = None, model_config: Dict[str, Any] = None, **kwargs, ): """ Args: parameters (Iterable[torch.nn.Parameter]): The parameters to track. decay (float): The decay factor for the exponential moving average. min_decay (float): The minimum decay factor for the exponential moving average. 
update_after_step (int): The number of steps to wait before starting to update the EMA weights. use_ema_warmup (bool): Whether to use EMA warmup. inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA weights will be stored on CPU. @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). """ if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " "Please pass the parameters of the module instead." ) deprecate( "passing a `torch.nn.Module` to `ExponentialMovingAverage`", "1.0.0", deprecation_message, standard_warn=False, ) parameters = parameters.parameters() # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility use_ema_warmup = True if kwargs.get("max_value", None) is not None: deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) decay = kwargs["max_value"] if kwargs.get("min_value", None) is not None: deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." 
deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) min_decay = kwargs["min_value"] parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] if kwargs.get("device", None) is not None: deprecation_message = "The `device` argument is deprecated. Please use `to` instead." deprecate("device", "1.0.0", deprecation_message, standard_warn=False) self.to(device=kwargs["device"]) self.temp_stored_params = None self.decay = decay self.min_decay = min_decay self.update_after_step = update_after_step self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma self.power = power self.optimization_step = 0 self.cur_decay_value = None # set in `step()` self.model_cls = model_cls self.model_config = model_config @classmethod def extract_ema_kwargs(cls, kwargs): """ Extracts the EMA kwargs from the kwargs of a class method. """ ema_kwargs = {} for key in [ "decay", "min_decay", "optimization_step", "update_after_step", "use_ema_warmup", "inv_gamma", "power", ]: if kwargs.get(key, None) is not None: ema_kwargs[key] = kwargs.pop(key) return ema_kwargs @classmethod def from_pretrained(cls, path, model_cls) -> "EMAModel": config = model_cls.load_config(path) ema_kwargs = cls.extract_ema_kwargs(config) model = model_cls.from_pretrained(path) ema_model = cls(model.parameters(), model_cls=model_cls, model_config=config) ema_model.load_state_dict(ema_kwargs) return ema_model def save_pretrained(self, path): if self.model_cls is None: raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") if self.model_config is None: raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") model = self.model_cls.from_config(self.model_config) state_dict = self.state_dict() state_dict.pop("shadow_params", None) model.register_to_config(**state_dict) self.copy_to(model.parameters()) model.save_pretrained(path) def get_decay(self, optimization_step: int) -> float: """ 
@torch.no_grad()
def step(self, parameters: Iterable[torch.nn.Parameter]):
    """
    Advance the exponential moving average by one optimization step.

    Increments ``self.optimization_step``, computes the decay for that step with
    ``self.get_decay`` (also stored in ``self.cur_decay_value``) and updates every
    tensor in ``self.shadow_params`` in place.

    Args:
        parameters: Iterable of `torch.nn.Parameter`, paired positionally with the
            shadow parameters. Passing a `torch.nn.Module` is deprecated but still
            accepted; its ``.parameters()`` are used.
    """
    if isinstance(parameters, torch.nn.Module):
        deprecation_message = (
            "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
            "Please pass the parameters of the module instead."
        )
        deprecate(
            "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
            "1.0.0",
            deprecation_message,
            standard_warn=False,
        )
        parameters = parameters.parameters()

    parameters = list(parameters)

    self.optimization_step += 1

    # Compute the decay factor for the exponential moving average.
    decay = self.get_decay(self.optimization_step)
    self.cur_decay_value = decay
    one_minus_decay = 1 - decay

    # Evaluate the ZeRO-3 check once instead of once per parameter.
    zero3_enabled = is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled()
    if zero3_enabled:
        import deepspeed

    for s_param, param in zip(self.shadow_params, parameters):
        # Always enter a context-manager *instance*. The previous code stored the
        # `nullcontext` class but a `GatheredParameters` instance in the same variable
        # and then called it, which raised TypeError whenever ZeRO-3 was enabled.
        if zero3_enabled:
            gather_ctx = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
        else:
            gather_ctx = contextlib.nullcontext()

        with gather_ctx:
            if param.requires_grad:
                # shadow <- shadow - (1 - decay) * (shadow - param)
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                # Parameters without gradients are mirrored verbatim.
                s_param.copy_(param)
""" parameters = list(parameters) for s_param, param in zip(self.shadow_params, parameters): param.data.copy_(s_param.to(param.device).data) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] def state_dict(self) -> dict: r""" Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during checkpointing to save the ema state dict. """ # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "min_decay": self.min_decay, "optimization_step": self.optimization_step, "update_after_step": self.update_after_step, "use_ema_warmup": self.use_ema_warmup, "inv_gamma": self.inv_gamma, "power": self.power, "shadow_params": self.shadow_params, } def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Save the current parameters for restoring later. parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. 
If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") for c_param, param in zip(self.temp_stored_params, parameters): param.data.copy_(c_param.data) # Better memory-wise. self.temp_stored_params = None def load_state_dict(self, state_dict: dict) -> None: r""" Args: Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ # deepcopy, to be consistent with module API state_dict = copy.deepcopy(state_dict) self.decay = state_dict.get("decay", self.decay) if self.decay < 0.0 or self.decay > 1.0: raise ValueError("Decay must be between 0 and 1") self.min_decay = state_dict.get("min_decay", self.min_decay) if not isinstance(self.min_decay, float): raise ValueError("Invalid min_decay") self.optimization_step = state_dict.get("optimization_step", self.optimization_step) if not isinstance(self.optimization_step, int): raise ValueError("Invalid optimization_step") self.update_after_step = state_dict.get("update_after_step", self.update_after_step) if not isinstance(self.update_after_step, int): raise ValueError("Invalid update_after_step") self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) if not isinstance(self.use_ema_warmup, bool): raise ValueError("Invalid use_ema_warmup") self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) if not isinstance(self.inv_gamma, (float, int)): raise ValueError("Invalid inv_gamma") self.power = state_dict.get("power", self.power) if not isinstance(self.power, (float, int)): raise ValueError("Invalid power") shadow_params = state_dict.get("shadow_params", None) if shadow_params is not None: self.shadow_params = shadow_params if not isinstance(self.shadow_params, list): 
class EMAModel(diffuser_EMAModel):
    """
    `diffusers` EMA wrapper that can additionally round-trip LoRA (PEFT) models.

    The EMA hyper-parameters are persisted next to the weights in ``ema_config.json``.
    """

    def __init__(self, parameters, **kwargs):
        """
        Args:
            parameters: Parameters to track; forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor, except ``lora_config``
                (optional), which is removed and stored on ``self.lora_config``.
        """
        self.lora_config = kwargs.pop('lora_config', None)
        super().__init__(parameters, **kwargs)

    @classmethod
    def from_pretrained(cls, path, model_cls, lora_config, model_base) -> "EMAModel":
        """
        Rebuild an EMA model from a directory written by `save_pretrained`.

        Args:
            path: Directory holding the saved weights and ``ema_config.json``.
            model_cls: Model class exposing ``from_pretrained``.
            lora_config: When not ``None``, the base model is loaded from ``model_base``
                and wrapped with ``PeftModel.from_pretrained(base, path)``.
            model_base: Location of the base model; only read when ``lora_config`` is set.

        Returns:
            The restored `EMAModel`.
        """
        # 1. load model
        if lora_config is not None:
            # 1.1 load origin model
            model_base = model_cls.from_pretrained(model_base)
            config = model_base.config
            # 1.2 convert to lora model automatically and load lora weight
            model = PeftModel.from_pretrained(model_base, path)
        else:
            model = model_cls.from_pretrained(path)
            config = model.config
        # 2. ema the whole model
        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=config, lora_config=lora_config)
        # 3. load ema_config, e.g decay...
        with open(os.path.join(path, 'ema_config.json'), 'r') as f:
            state_dict = json.load(f)
        ema_model.load_state_dict(state_dict)
        return ema_model

    def save_pretrained(self, path):
        """
        Write the EMA weights and the EMA hyper-parameters to ``path``.

        Raises:
            ValueError: If ``model_cls`` or ``model_config`` was not given at construction.
        """
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        # 1. init a base model randomly
        model = self.model_cls.from_config(self.model_config)
        # 1.1 convert lora_model
        if self.lora_config is not None:
            model = get_peft_model(model, self.lora_config)
        # 2. ema_model copy to model
        self.copy_to(model.parameters())
        # 3. save weight
        if self.lora_config is not None:
            model.save_pretrained(path)  # only lora weight
            merge_model = model.merge_and_unload()
            merge_model.save_pretrained(path)  # merge_model weight
        else:
            # `merge_model` only exists in the LoRA branch; without LoRA the plain
            # model itself has to be saved (the old code raised NameError here).
            model.save_pretrained(path)  # model weight
        # 4. save ema_config, e.g decay... (tensors are dropped so the dict is JSON-serializable)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)
        with open(os.path.join(path, 'ema_config.json'), 'w') as f:
            json.dump(state_dict, f, indent=2)
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Build a gaussian low-pass mask over the last three axes of ``shape``.

    Args:
        shape: shape of the filter (volume); the last three entries are read as (T, H, W)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        A ``torch.zeros(shape)`` tensor filled with the filter response; all zeros when
        ``d_s`` or ``d_t`` is 0.
    """
    frames, rows, cols = shape[-3], shape[-2], shape[-1]
    lpf = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return lpf

    exponent_scale = -1 / (2 * d_s**2)
    # Per-axis squared, centred coordinates; the temporal axis is rescaled by d_s / d_t.
    t_terms = [((d_s / d_t) * (2 * ti / frames - 1)) ** 2 for ti in range(frames)]
    h_terms = [(2 * hi / rows - 1) ** 2 for hi in range(rows)]
    w_terms = [(2 * wi / cols - 1) ** 2 for wi in range(cols)]

    for ti, t_term in enumerate(t_terms):
        for hi, h_term in enumerate(h_terms):
            for wi, w_term in enumerate(w_terms):
                lpf[..., ti, hi, wi] = math.exp(exponent_scale * (t_term + h_term + w_term))
    return lpf
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask.

    A position is kept (1) when its squared normalized distance from the centre is at
    most ``d_s**2`` and dropped (0) otherwise.

    Args:
        shape: shape of the filter (volume); the last three entries are read as (T, H, W)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        A ``torch.zeros(shape)`` tensor holding the 0/1 mask; all zeros when ``d_s`` or
        ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask
    # `d_square` is a squared distance, so it must be compared with the squared
    # cutoff. The previous `d_s*2` made the pass band far too wide and inconsistent
    # with the gaussian / butterworth filters, which both normalize by d_s**2.
    cutoff_square = d_s ** 2
    for t in range(T):
        for h in range(H):
            for w in range(W):
                d_square = (((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2)
                mask[..., t, h, w] = 1 if d_square <= cutoff_square else 0
    return mask
def maybe_zero_3(param, ignore_status=False, name=None):
    """
    Return a detached CPU clone of ``param``.

    Parameters carrying a ``ds_id`` attribute are cloned inside
    ``deepspeed.zero.GatheredParameters``; anything else is cloned directly.

    Args:
        param: Tensor / parameter to copy.
        ignore_status: When False, log a warning if ``param.ds_status`` is
            ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Label used in that warning.
    """
    if hasattr(param, "ds_id"):
        # deepspeed is only needed for parameters that carry `ds_id`, so plain tensors
        # can be processed without deepspeed installed.
        from deepspeed import zero
        from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                # The module never imports `logging` at file level; import it here so
                # the warning no longer fails with NameError.
                import logging

                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """
    Collect the LoRA-related entries of ``named_params`` as CPU clones.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs.
        bias: ``"none"`` keeps names containing ``"lora_"``; ``"all"`` additionally keeps
            names containing ``"bias"``; ``"lora_only"`` keeps the ``"lora_"`` entries
            plus the bias whose name is ``<prefix before "lora_">bias`` for some kept entry.

    Returns:
        Dict mapping each kept name to the result of ``maybe_zero_3``.

    Raises:
        NotImplementedError: For any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate over (name, tensor) pairs and test each candidate's own name. The
        # old loop unpacked the dict's keys and re-checked the stale `bias_name` left
        # over from the loop above.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
class ContinuationMaskGenerator(BaseMaskGenerator):
    """Mask generator that zeroes a random-length prefix along the first axis of the mask."""

    def __init__(self, min_clear_ratio=0.0, max_clear_ratio=1.0):
        """
        Args:
            min_clear_ratio: Lower bound, as a fraction of the frame count, for the
                cleared prefix length. Must lie in [0, 1].
            max_clear_ratio: Upper bound for the same fraction. Must lie in [0, 1] and be
                at least ``min_clear_ratio``.
        """
        assert 0 <= min_clear_ratio <= 1, 'min_clear_ratio should be in the range of [0, 1].'
        assert 0 <= max_clear_ratio <= 1, 'max_clear_ratio should be in the range of [0, 1].'
        assert min_clear_ratio <= max_clear_ratio, 'min_clear_ratio should be less than max_clear_ratio.'
        self.min_clear_ratio = min_clear_ratio
        self.max_clear_ratio = max_clear_ratio

    def process(self, mask):
        """Zero ``mask[:k]`` in place for a randomly drawn ``k`` and return ``mask``."""
        frame_count = mask.shape[0]
        lowest = floor(frame_count * self.min_clear_ratio)
        highest = ceil(frame_count * self.max_clear_ratio)
        # random.randint is inclusive on both ends.
        clear_until = random.randint(lowest, highest)
        mask[:clear_until] = 0
        return mask
class MaskProcessor:
    """Creates a mask for a clip and applies it to the pixel values."""

    def __init__(
        self,
        max_height=640,
        max_width=640,
        min_clear_ratio=0.0,
        max_clear_ratio=1.0,
    ):
        """
        Args:
            max_height: Stored on the instance; not read by any method of this class.
            max_width: Stored on the instance; not read by any method of this class.
            min_clear_ratio: Forwarded to the continuation and random-temporal generators.
            max_clear_ratio: Forwarded to the continuation and random-temporal generators.
        """
        self.max_height = max_height
        self.max_width = max_width
        self.min_clear_ratio = min_clear_ratio
        self.max_clear_ratio = max_clear_ratio

        self.init_mask_generators()

    def init_mask_generators(self):
        """Instantiate one generator per `MaskType` and store them in ``self.mask_generators``."""
        self.mask_generators = {
            MaskType.t2iv: T2IVMaskGenerator(),
            MaskType.i2v: I2VMaskGenerator(),
            MaskType.transition: TransitionMaskGenerator(),
            MaskType.continuation: ContinuationMaskGenerator(min_clear_ratio=self.min_clear_ratio, max_clear_ratio=self.max_clear_ratio),
            MaskType.clear: ClearMaskGenerator(),
            MaskType.random_temporal: RandomTemporalMaskGenerator(min_clear_ratio=self.min_clear_ratio, max_clear_ratio=self.max_clear_ratio),
        }

    def get_mask(self, mask_generator_type, num_frames, height, width, device='cuda', dtype=torch.float32):
        """Run the generator registered for ``mask_generator_type`` and return its mask."""
        return self.mask_generators[mask_generator_type](num_frames, height, width, device=device, dtype=dtype)

    def __call__(self, pixel_values, mask_type=None, mask_type_ratio_dict=None):
        """
        Pick a mask type, build the mask and apply it.

        Args:
            pixel_values: 4-D tensor; axes 0, 2 and 3 are read as frames, height and width.
            mask_type: A `MaskType` member or its name. Only used when
                ``mask_type_ratio_dict`` is ``None``.
            mask_type_ratio_dict: Dict mapping `MaskType` members to sampling weights; when
                given, one type is drawn with ``random.choices``.

        Returns:
            ``dict(mask=..., masked_pixel_values=...)`` where the masked values are
            ``pixel_values * (mask < 0.5)``.

        Raises:
            ValueError: If neither ``mask_type`` nor ``mask_type_ratio_dict`` is given.
            AssertionError: If an unknown mask type is supplied.
        """
        num_frames, _, height, width = pixel_values.shape
        if mask_type_ratio_dict is not None:
            assert isinstance(mask_type_ratio_dict, dict), 'mask_type_ratio_dict should be a dict.'
            # Report the offending keys; the old message printed the valid types that
            # were merely absent from the dict.
            assert mask_type_ratio_dict.keys() <= set(MaskType), f'Invalid mask type: {mask_type_ratio_dict.keys() - set(MaskType)}'
            mask_generator_type = random.choices(list(mask_type_ratio_dict.keys()), list(mask_type_ratio_dict.values()))[0]
        elif mask_type is not None:
            # `x in MaskType` raises TypeError for non-Enum operands on Python < 3.12,
            # so string names are dispatched with isinstance instead.
            if isinstance(mask_type, MaskType):
                mask_generator_type = mask_type
            else:
                assert mask_type in STR_TO_TYPE.keys(), f'Invalid mask type: {mask_type}'
                mask_generator_type = STR_TO_TYPE[mask_type]
        else:
            raise ValueError('mask_type or mask_type_ratio_dict should be provided.')

        mask = self.get_mask(mask_generator_type, num_frames, height, width, device=pixel_values.device, dtype=pixel_values.dtype)

        # Entries where the mask is below 0.5 are kept; the rest are zeroed.
        masked_pixel_values = pixel_values * (mask < 0.5)
        return dict(mask=mask, masked_pixel_values=masked_pixel_values)
class GaussianNoiseAdder(BaseNoiseAdder):
    """Adds gaussian noise with a log-normally distributed per-sample scale to the kept region."""

    def __init__(self, mean=-3.0, std=0.5, clear_ratio=0.05):
        """
        Args:
            mean: Mean of the normal distribution the log-sigma is drawn from.
            std: Standard deviation of that distribution.
            clear_ratio: Probability of returning the input unchanged.
        """
        self.mean = mean
        self.std = std
        self.clear_ratio = clear_ratio

    # pixel_values: (B, C, T, H, W)
    # mask: (B, 1, T, H, W)
    def add_noise(self, masked_pixel_values, mask):
        """Return ``masked_pixel_values`` plus noise wherever ``mask < 0.5``."""
        skip_noise = random.random() < self.clear_ratio
        if skip_noise:
            return masked_pixel_values

        batch = masked_pixel_values.shape[0]
        # One log-sigma per sample, exponentiated to obtain the noise scale.
        log_sigma = torch.normal(
            mean=self.mean,
            std=self.std,
            size=(batch,),
            device=masked_pixel_values.device,
        )
        sigma = torch.exp(log_sigma).to(dtype=masked_pixel_values.dtype)
        scaled_noise = torch.randn_like(masked_pixel_values) * sigma[:, None, None, None, None]
        # Noise is suppressed wherever the mask is >= 0.5.
        kept_noise = torch.where(mask < 0.5, scaled_noise, torch.zeros_like(scaled_noise))
        return masked_pixel_values + kept_noise
def destroy_sequence_parallel_group():
    """Destroy the sequence parallel group.

    Calls ``dist.destroy_process_group()`` with no arguments. The call is skipped when
    torch.distributed is unavailable or no process group has been initialized, so the
    function is safe to call from single-process runs and more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def get_scheduler(args):
    """
    Instantiate the diffusers scheduler selected by ``args.sample_method``.

    Reads ``args.prediction_type``, ``args.rescale_betas_zero_snr``,
    ``args.v1_5_scheduler`` and ``args.sample_method``.

    Raises:
        NameError: If ``args.sample_method`` is not one of the supported names.
    """
    scheduler_kwargs = {
        'prediction_type': args.prediction_type,
        'rescale_betas_zero_snr': args.rescale_betas_zero_snr,
        'timestep_spacing': "trailing" if args.rescale_betas_zero_snr else 'leading',
    }
    if args.v1_5_scheduler:
        scheduler_kwargs.update(
            beta_start=0.00085,
            beta_end=0.0120,
            beta_schedule="scaled_linear",
        )

    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
        'CogVideoX': CogVideoXDDIMScheduler,
        'FlowMatchEulerDiscrete': FlowMatchEulerDiscreteScheduler,
    }

    method = args.sample_method
    if method not in scheduler_classes:
        raise NameError(f'Unsupport sample_method {args.sample_method}')

    # Per-scheduler adjustments of the constructor arguments.
    if method in ('DDIM', 'DDPM'):
        scheduler_kwargs['clip_sample'] = False
    elif method in ('PNDM', 'DEISMultistep'):
        scheduler_kwargs.pop('rescale_betas_zero_snr', None)
    elif method == 'FlowMatchEulerDiscrete':
        # This scheduler is built with its own defaults only.
        scheduler_kwargs = {}

    return scheduler_classes[method](**scheduler_kwargs)
ae_wrapper[args.ae](args.ae_path) vae.vae = vae.vae.to(device=device, dtype=weight_dtype).eval() vae.vae_scale_factor = ae_stride_config[args.ae] if args.enable_tiling: vae.vae.enable_tiling() if 'mt5' in args.text_encoder_name_1: text_encoder_1 = MT5EncoderModel.from_pretrained( args.text_encoder_name_1, cache_dir=args.cache_dir, torch_dtype=weight_dtype ).eval() else: text_encoder_1 = T5EncoderModel.from_pretrained( args.text_encoder_name_1, cache_dir=args.cache_dir, torch_dtype=weight_dtype ).eval() tokenizer_1 = AutoTokenizer.from_pretrained( args.text_encoder_name_1, cache_dir=args.cache_dir ) if args.text_encoder_name_2 is not None: text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( args.text_encoder_name_2, cache_dir=args.cache_dir, torch_dtype=weight_dtype ).eval() tokenizer_2 = AutoTokenizer.from_pretrained( args.text_encoder_name_2, cache_dir=args.cache_dir ) else: text_encoder_2, tokenizer_2 = None, None if args.version == 'v1_3': if args.model_type == 'inpaint' or args.model_type == 'i2v': transformer_model = OpenSoraInpaint_v1_3.from_pretrained( args.model_path, cache_dir=args.cache_dir, device_map=None, torch_dtype=weight_dtype ).eval() else: transformer_model = OpenSoraT2V_v1_3.from_pretrained( args.model_path, cache_dir=args.cache_dir, device_map=None, torch_dtype=weight_dtype ).eval() elif args.version == 'v1_5': if args.model_type == 'inpaint' or args.model_type == 'i2v': raise NotImplementedError('Inpainting model is not available in v1_5') else: from opensora.models.diffusion.opensora_v1_5.modeling_opensora import OpenSoraT2V_v1_5 transformer_model = OpenSoraT2V_v1_5.from_pretrained( args.model_path, cache_dir=args.cache_dir, # device_map=None, torch_dtype=weight_dtype ).eval() scheduler = get_scheduler(args) pipeline_class = OpenSoraInpaintPipeline if args.model_type == 'inpaint' or args.model_type == 'i2v' else OpenSoraPipeline pipeline = pipeline_class( vae=vae, text_encoder=text_encoder_1, tokenizer=tokenizer_1, 
def save_video_grid(video, nrow=None):
    """
    Tile a batch of videos into a single grid video.

    Args:
        video: Tensor unpacked as (b, t, h, w, c).
        nrow: Number of grid rows; defaults to ``ceil(sqrt(b))``.

    Returns:
        A ``torch.uint8`` tensor of shape
        ``(t, (1 + h) * nrow + 1, (1 + w) * ncol + 1, c)`` with ``ncol = ceil(b / nrow)``;
        cells are filled row by row and everything else stays zero.
    """
    batch, frames, height, width, channels = video.shape
    if nrow is None:
        nrow = math.ceil(math.sqrt(batch))
    ncol = math.ceil(batch / nrow)
    pad = 1
    cell_h = pad + height
    cell_w = pad + width

    grid = torch.zeros(
        (frames, cell_h * nrow + pad, cell_w * ncol + pad, channels),
        dtype=torch.uint8,
    )
    for idx in range(batch):
        row, col = divmod(idx, ncol)
        top = cell_h * row
        left = cell_w * col
        grid[:, top:top + height, left:left + width] = video[idx]
    return grid
not os.path.exists(args.save_img_path): os.makedirs(args.save_img_path, exist_ok=True) video_grids = [] if not isinstance(args.text_prompt, list): args.text_prompt = [args.text_prompt] if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'): text_prompt = open(args.text_prompt[0], 'r').readlines() args.text_prompt = [i.strip() for i in text_prompt] if args.model_type == 'inpaint' or args.model_type == 'i2v': if not isinstance(args.conditional_pixel_values_path, list): args.conditional_pixel_values_path = [args.conditional_pixel_values_path] if len(args.conditional_pixel_values_path) == 1 and args.conditional_pixel_values_path[0].endswith('txt'): temp = open(args.conditional_pixel_values_path[0], 'r').readlines() conditional_pixel_values_path = [i.strip().split(',') for i in temp] mask_type = args.mask_type if args.mask_type is not None else None positive_prompt = """ high quality, high aesthetic, {} """ negative_prompt = """ nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. """ def generate(prompt, conditional_pixel_values_path=None, mask_type=None): if args.caption_refiner is not None: if args.model_type != 'inpaint' and args.model_type != 'i2v': refine_prompt = caption_refiner_model.get_refiner_output(prompt) print(f'\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}') prompt = refine_prompt else: # Due to the current use of LLM as the caption refiner, additional content that is not present in the control image will be added. Therefore, caption refiner is not used in this mode. 
print('Caption refiner is not available for inpainting model, use the original prompt...') time.sleep(3) input_prompt = positive_prompt.format(prompt) if args.model_type == 'inpaint' or args.model_type == 'i2v': print(f'\nConditional pixel values path: {conditional_pixel_values_path}') videos = pipeline( conditional_pixel_values_path=conditional_pixel_values_path, mask_type=mask_type, crop_for_hw=args.crop_for_hw, max_hxw=args.max_hxw, noise_strength=args.noise_strength, prompt=input_prompt, negative_prompt=negative_prompt, num_frames=args.num_frames, height=args.height, width=args.width, num_inference_steps=args.num_sampling_steps, guidance_scale=args.guidance_scale, num_samples_per_prompt=args.num_samples_per_prompt, max_sequence_length=args.max_sequence_length, ).videos else: videos = pipeline( input_prompt, negative_prompt=negative_prompt, num_frames=args.num_frames, height=args.height, width=args.width, num_inference_steps=args.num_sampling_steps, guidance_scale=args.guidance_scale, num_samples_per_prompt=args.num_samples_per_prompt, max_sequence_length=args.max_sequence_length, ).videos if enhance_video_model is not None: # b t h w c videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250) if (not args.sp) or (args.sp and args.local_rank <= 0): if args.num_frames == 1: videos = rearrange(videos, 'b t h w c -> (b t) c h w') if args.num_samples_per_prompt != 1: for i, image in enumerate(videos): save_image( image / 255.0, os.path.join( args.save_img_path, f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}_i{i}.jpg' ), nrow=math.ceil(math.sqrt(videos.shape[0])), normalize=True, value_range=(0, 1) ) # b c h w save_image( videos / 255.0, os.path.join( args.save_img_path, f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.jpg' ), nrow=math.ceil(math.sqrt(videos.shape[0])), normalize=True, value_range=(0, 1) ) # b c h w else: if args.num_samples_per_prompt == 1: 
imageio.mimwrite( os.path.join( args.save_img_path, f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4' ), videos[0], fps=args.fps, quality=6 ) # highest quality is 10, lowest is 0 else: for i in range(args.num_samples_per_prompt): imageio.mimwrite( os.path.join( args.save_img_path, f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}_i{i}.mp4' ), videos[i], fps=args.fps, quality=6 ) # highest quality is 10, lowest is 0 videos = save_video_grid(videos) imageio.mimwrite( os.path.join( args.save_img_path, f'{args.sample_method}_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4' ), videos, fps=args.fps, quality=6 ) # highest quality is 10, lowest is 0) videos = videos.unsqueeze(0) # 1 t h w c video_grids.append(videos) if args.model_type == 'inpaint' or args.model_type == 'i2v': for index, (prompt, cond_path) in enumerate(zip(args.text_prompt, conditional_pixel_values_path)): if not args.sp and args.local_rank != -1 and index % args.world_size != args.local_rank: continue generate(prompt, cond_path, mask_type) else: for index, prompt in enumerate(args.text_prompt): if not args.sp and args.local_rank != -1 and index % args.world_size != args.local_rank: continue # skip when ddp generate(prompt) if (args.model_type == "inpaint" or args.model_type == "i2v") and not args.crop_for_hw: print('completed, please check the saved images and videos') else: if not args.sp: if args.local_rank != -1: dist.barrier() video_grids = torch.cat(video_grids, dim=0).cuda() shape = list(video_grids.shape) shape[0] *= args.world_size gathered_tensor = torch.zeros(shape, dtype=video_grids.dtype).cuda() dist.all_gather_into_tensor(gathered_tensor, video_grids.contiguous()) video_grids = gathered_tensor.cpu() dist.barrier() else: video_grids = torch.cat(video_grids, dim=0) elif args.sp and args.local_rank <= 0: video_grids = torch.cat(video_grids) if args.local_rank <= 0: if args.num_frames == 1: save_image( 
video_grids / 255.0, os.path.join( args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.jpg' ), nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1) ) else: video_grids = save_video_grid(video_grids) imageio.mimwrite( os.path.join( args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4' ), video_grids, fps=args.fps, quality=6 ) print('save path {}'.format(args.save_img_path)) def run_model_and_save_samples_npu(args, pipeline, caption_refiner_model=None, enhance_video_model=None): # experimental_config = torch_npu.profiler._ExperimentalConfig( # profiler_level=torch_npu.profiler.ProfilerLevel.Level1, # aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization # ) # profile_output_path = "/home/image_data/npu_profiling_t2v" # os.makedirs(profile_output_path, exist_ok=True) # with torch_npu.profiler.profile( # activities=[ # torch_npu.profiler.ProfilerActivity.NPU, # torch_npu.profiler.ProfilerActivity.CPU # ], # with_stack=True, # record_shapes=True, # profile_memory=True, # experimental_config=experimental_config, # schedule=torch_npu.profiler.schedule( # wait=10000, warmup=0, active=1, repeat=1, skip_first=0 # ), # on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(f"{profile_output_path}/") # ) as prof: run_model_and_save_samples(args, pipeline, caption_refiner_model, enhance_video_model) # prof.step() def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0') parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5']) parser.add_argument("--model_type", type=str, default='t2v', choices=['t2v', 'inpaint', 'i2v']) parser.add_argument("--num_frames", type=int, default=1) parser.add_argument("--height", type=int, default=512) parser.add_argument("--width", type=int, default=512) parser.add_argument("--device", type=str, default='cuda:0') 
parser.add_argument("--cache_dir", type=str, default='./cache_dir') parser.add_argument("--caption_refiner", type=str, default=None) parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8') parser.add_argument("--ae_path", type=str, default='CausalVAEModel_4x8x8') parser.add_argument("--enhance_video", type=str, default=None) parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl') parser.add_argument("--text_encoder_name_2", type=str, default=None) parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v") parser.add_argument("--guidance_scale", type=float, default=7.5) parser.add_argument("--sample_method", type=str, default="PNDM") parser.add_argument("--num_sampling_steps", type=int, default=50) parser.add_argument("--fps", type=int, default=24) parser.add_argument("--max_sequence_length", type=int, default=512) parser.add_argument("--text_prompt", nargs='+') parser.add_argument("--seed", type=int, default=42) parser.add_argument("--num_samples_per_prompt", type=int, default=1) parser.add_argument('--enable_tiling', action='store_true') parser.add_argument('--refine_caption', action='store_true') parser.add_argument('--compile', action='store_true') parser.add_argument('--save_memory', action='store_true') parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def to_2tuple(x):
    """Return ``x`` unchanged if it is iterable, otherwise the pair ``(x, x)``."""
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)


def explicit_uniform_sampling(T, n, rank, bsz, device):
    """
    Explicit Uniform Sampling with integer timesteps and PyTorch.

    The range ``[0, T)`` is split into ``n`` equal-width intervals and each rank
    draws its timesteps from its own interval.

    Args:
        T (int): Number of timesteps; returned values lie in ``[0, T - 1]``.
        n (int): Number of ranks (data parallel processes).
        rank (int): The rank of the current process (from 0 to n-1).
        bsz (int): Batch size, number of timesteps to return.
        device: Device on which the returned tensor is created.

    Returns:
        torch.Tensor: A ``torch.long`` tensor of shape (bsz,) containing uniformly
        sampled integer timesteps within the rank's interval.
    """
    # True division: the interval boundaries may be fractional. They are shifted by
    # -0.5 so that rounding a uniform draw lands on the integers of the interval.
    interval_size = T / n
    lower_bound = interval_size * rank - 0.5
    upper_bound = interval_size * (rank + 1) - 0.5

    # Draw once. The previous implementation sampled the whole batch twice and threw
    # the first batch away, which doubled the work and consumed extra RNG state.
    sampled = [round(random.uniform(lower_bound, upper_bound)) for _ in range(bsz)]
    sampled_timesteps = torch.tensor(sampled, dtype=torch.long, device=device)

    # round() uses banker's rounding, so a draw at the very top of the last interval
    # (T - 0.5) can round up to T; clamp to keep every value a valid index.
    return sampled_timesteps.clamp_(0, T - 1)
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad=True) -> torch.Tensor:
    r"""
    Adapted from torch.nn.utils.clip_grad_norm_.

    Computes the norm over all gradients together, as if they were concatenated
    into a single vector, and (when ``clip_grad`` is true) rescales the gradients
    in-place so that this norm does not exceed ``max_norm``.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients is ``nan``, ``inf``, or ``-inf``. Only checked
            when ``clip_grad`` is true.
        clip_grad (bool): if False, only the norm is computed and the gradients
            are left untouched.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector),
        measured before any clipping.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    gradients = [param.grad for param in params if param.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not gradients:
        return torch.tensor(0.)

    target_device = gradients[0].device
    if norm_type == inf:
        per_grad = [grad.detach().abs().max().to(target_device) for grad in gradients]
        total_norm = per_grad[0] if len(per_grad) == 1 else torch.max(torch.stack(per_grad))
    else:
        per_grad = [torch.norm(grad.detach(), norm_type).to(target_device) for grad in gradients]
        total_norm = torch.norm(torch.stack(per_grad), norm_type)

    if not clip_grad:
        return total_norm

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    # Multiplying by a coefficient clamped to 1 is a no-op when no clipping is needed,
    # and avoids a data-dependent branch that could force a CPU <=> device sync.
    scale = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for grad in gradients:
        grad.detach().mul_(scale.to(grad.device))
    return total_norm


def get_experiment_dir(root_dir, args):
    """Append a suffix describing the run configuration to ``root_dir``.

    Suffixes are added, in order, for ``args.use_compile``, ``args.attention_mode``,
    ``args.gradient_checkpointing`` and ``args.mixed_precision`` when they are
    truthy, followed unconditionally by ``args.max_image_size``.
    """
    suffixes = []
    if args.use_compile:
        suffixes.append('-Compile')  # speedup by torch compile
    if args.attention_mode:
        suffixes.append(f'-{args.attention_mode.upper()}')
    if args.gradient_checkpointing:
        suffixes.append('-Gc')
    if args.mixed_precision:
        suffixes.append(f'-{args.mixed_precision.upper()}')
    suffixes.append(f'-{args.max_image_size}')
    for suffix in suffixes:
        root_dir += suffix
    return root_dir


def get_precision(args):
    """Map ``args.mixed_precision`` to a torch dtype (``float32`` unless bf16/fp16)."""
    mode = args.mixed_precision
    if mode == "bf16":
        return torch.bfloat16
    if mode == "fp16":
        return torch.float16
    return torch.float32
#################################################################################
#                             Training Logger                                   #
#################################################################################

def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.

    On rank 0 the root logging configuration is set up with a stream handler and a
    file handler writing to ``{logging_dir}/log.txt``. On every other rank the
    returned logger only gets a ``NullHandler`` attached.
    """
    if dist.get_rank() == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            # format='[\033[34m%(asctime)s\033[0m] %(message)s',
            format='[%(asctime)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger


def create_tensorboard(tensorboard_dir):
    """
    Create a tensorboard writer that saves losses.

    Returns:
        A ``SummaryWriter`` on rank 0 and ``None`` on every other rank.
    """
    # Previously `writer` was only assigned on rank 0, so the `return writer` below
    # raised UnboundLocalError on all other ranks. Return None there instead.
    if dist.get_rank() != 0:
        return None
    return SummaryWriter(tensorboard_dir)
def get_npu_power():
    """Query ``npu-smi info`` and return the reported power per NPU.

    Returns:
        dict: ``{"NPU_<i>_Power_W": float}``, numbered in the order the rows appear
        after a ``| NPU`` header line. Rows whose power field is missing or not
        numeric are skipped.
    """
    result = subprocess.run(["npu-smi", "info"], stdout=subprocess.PIPE, text=True)
    power_data = {}
    npu_id = None

    # Parse the output of npu-smi.
    for line in result.stdout.splitlines():
        if line.startswith("| NPU"):
            npu_id = 0  # a header line starts a new run of NPU records
        elif line.startswith("|") and npu_id is not None:
            parts = line.split("|")
            if len(parts) > 4:
                try:
                    power = float(parts[4].strip().split()[0])  # extract Power(W)
                except (IndexError, ValueError):
                    # Secondary header rows and rows with an empty field carry no
                    # power reading; previously they raised and aborted the parse.
                    continue
                # Record the power reading of each NPU.
                power_data[f"NPU_{npu_id}_Power_W"] = power
                npu_id += 1

    return power_data


def monitor_npu_power():
    """Log NPU power to wandb every 10 seconds for as long as a wandb run is active."""
    while wandb.run is not None:
        power_data = get_npu_power()
        wandb.log(power_data)  # stream the NPU power readings to wandb
        time.sleep(10)  # sample once every 10 seconds


#################################################################################
#                      EMA Update/ DDP Training Utils                           #
#################################################################################

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``, matched by name.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def cleanup():
    """
    End DDP training.
    """
    dist.destroy_process_group()


# adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/random.py#L31
def set_seed(seed, rank, device_specific=True):
    """Seed Python, NumPy and torch RNGs and make cuDNN deterministic.

    Args:
        seed (int): Base seed.
        rank (int): Added to ``seed`` when ``device_specific`` is true.
        device_specific (bool): Whether to offset the seed by ``rank``.
    """
    if device_specific:
        seed += rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def setup_distributed(backend="nccl", port=None):
    """Initialize distributed training environment.
    support both slurm and torch.distributed.launch
    see torch.distributed.init_process_group() for more details

    Under SLURM (``SLURM_JOB_ID`` set) the rank, world size and master address are
    derived from SLURM variables and exported as ``RANK`` / ``LOCAL_RANK`` /
    ``WORLD_SIZE`` / ``MASTER_ADDR`` / ``MASTER_PORT``. Otherwise ``RANK`` and
    ``WORLD_SIZE`` must already be present in the environment.
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
        # specify master port
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            # os.environ["MASTER_PORT"] = "29566"
            os.environ["MASTER_PORT"] = str(29567 + num_gpus)
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # torch.cuda.set_device(rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
#################################################################################
#                          Pixart-alpha Utils                                   #
#################################################################################

# One raw string instead of a concatenation of non-raw pieces such as '\)' and '\/':
# those are invalid escape sequences that newer Python versions warn about.
# The compiled pattern is unchanged.
bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}')  # noqa


def text_preprocessing(text, support_Chinese=True):
    """Apply the caption cleaning used at training time (see :func:`clean_caption`)."""
    # The exact text cleaning as was in the training stage:
    text = clean_caption(text, support_Chinese=support_Chinese)
    return text


def basic_clean(text):
    """Fix text with ``ftfy``, unescape HTML entities twice and strip whitespace."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def clean_caption(caption, support_Chinese=True):
    """Normalise a raw caption: lower-case it and strip URLs, markup and boilerplate.

    Args:
        caption: Any object; it is converted with ``str()`` first.
        support_Chinese (bool): When False, CJK Unified Ideographs
            (``\\u4e00-\\u9fff``) are removed as well.

    Returns:
        str: The cleaned, stripped caption.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the "<person>" placeholder. The pattern had degraded to an empty string,
    # which made re.sub insert "person" between every character of the caption.
    caption = re.sub('<person>', 'person', caption)
    # urls:
    caption = re.sub(
        r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
        '', caption)  # regex for urls
    caption = re.sub(
        r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
        '', caption)  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features='html.parser').text

    # @nickname mentions
    caption = re.sub(r'@[\w\d]+\b', '', caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
    caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
    caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
    caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
    caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
    caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
    if not support_Chinese:
        caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)  # Chinese
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
        '-', caption)

    # normalise quotation marks to one standard
    caption = re.sub(r'[`´«»“”¨]', '"', caption)
    caption = re.sub(r'[‘’]', "'", caption)

    # Leftover HTML entities. These two patterns had degraded to r'"?' and r'&', which
    # deleted every double quote (right after normalising them above) and every
    # ampersand instead of only the "&quot;" / "&amp" entities.
    caption = re.sub(r'&quot;?', '', caption)
    caption = re.sub(r'&amp', '', caption)

    # ip adresses:
    caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

    # article ids:
    caption = re.sub(r'\d:\d\d\s+$', '', caption)

    # \n
    caption = re.sub(r'\\n', ' ', caption)

    # "#123"
    caption = re.sub(r'#\d{1,3}\b', '', caption)
    # "#12345.."
    caption = re.sub(r'#\d{5,}\b', '', caption)
    # "123456.."
    caption = re.sub(r'\b\d{6,}\b', '', caption)
    # filenames:
    caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

    #
    caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r'[\.]{2,}', r' ', caption)  # """AUSVERKAUFT"""

    caption = re.sub(bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r'(?:\-|\_)')
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, ' ', caption)

    caption = basic_clean(caption)

    caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
    caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
    caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

    caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
    caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
    caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
    caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
    caption = re.sub(r'\bpage\s+\d+\b', '', caption)

    caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

    caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

    caption = re.sub(r'\b\s+\:\s+', r': ', caption)
    caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
    caption = re.sub(r'\s+', ' ', caption)

    # The result of strip() used to be discarded, so the anchored patterns below
    # could be defeated by a leading or trailing space.
    caption = caption.strip()

    caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
    caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
    caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
    caption = re.sub(r'^\.\S+$', '', caption)

    return caption.strip()


if __name__ == '__main__':
    # caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
    a = "امرأة مسنة بشعر أبيض ووجه مليء بالتجاعيد تجلس داخل سيارة قديمة الطراز، تنظر من خلال النافذة الجانبية بتعبير تأملي أو حزين قليلاً."
    print(a)
    print(text_preprocessing(a))
readme = "README.md" requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] dependencies = [ "transformers==4.44.2", "tokenizers==0.19.1", "albumentations==1.4.0", "av==11.0.0", "decord==0.6.0", "einops==0.7.0", "fastapi==0.110.0", "gdown==5.1.0", "h5py==3.10.0", "idna==3.8", 'imageio==2.34.0', "matplotlib==3.7.5", "numpy==1.24.4", "omegaconf==2.1.1", "opencv-python==4.9.0.80", "opencv-python-headless==4.9.0.80", "pandas==2.0.3", "pillow==10.2.0", "pydub==0.25.1", "pytorchvideo==0.1.5", "PyYAML==6.0.2", "regex==2024.7.24", "requests==2.32.3", "scikit-learn==1.3.2", "scipy==1.10.1", "six==1.16.0", "test-tube==0.7.5", "timm==0.9.16", "torchdiffeq==0.2.3", "torchmetrics==1.3.2", "tqdm==4.66.5", "urllib3==2.2.2", "uvicorn==0.27.1", "scikit-video==1.1.11", "imageio-ffmpeg==0.4.9", "sentencepiece==0.1.99", "beautifulsoup4==4.12.3", "ftfy==6.1.3", "moviepy==1.0.3", "wandb==0.16.3", "tensorboard==2.14.0", "pydantic==2.6.4", "gradio==4.0.0", "torch==2.1.0", "torchvision==0.16.0", "xformers==0.0.22.post7", "accelerate==0.34.0", "diffusers==0.30.2", "deepspeed==0.12.6" ] [project.optional-dependencies] dev = ["mypy==1.8.0"] [project.urls] "Homepage" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan" "Bug Tracker" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues" [tool.setuptools.packages.find] exclude = ["assets*", "docker*", "docs", "scripts*"] [tool.wheel] exclude = ["assets*", "docker*", "docs", "scripts*"] [tool.mypy] warn_return_any = true warn_unused_configs = true ignore_missing_imports = true disallow_untyped_calls = true check_untyped_defs = true no_implicit_optional = true ================================================ FILE: scripts/accelerate_configs/ddp_config.yaml ================================================ compute_environment: LOCAL_MACHINE distributed_type: MULTI_GPU fsdp_config: {} machine_rank: 0 main_process_ip: null main_process_port: 29501 
main_training_function: main num_machines: 1 num_processes: 1 gpu_ids: 0, use_cpu: false ================================================ FILE: scripts/accelerate_configs/deepspeed_zero2_config.yaml ================================================ compute_environment: LOCAL_MACHINE distributed_type: DEEPSPEED deepspeed_config: deepspeed_config_file: scripts/accelerate_configs/zero2.json fsdp_config: {} machine_rank: 0 main_process_ip: null main_process_port: 29513 main_training_function: main num_machines: 1 num_processes: 8 gpu_ids: 0,1,2,3,4,5,6,7 use_cpu: false ================================================ FILE: scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml ================================================ compute_environment: LOCAL_MACHINE distributed_type: DEEPSPEED deepspeed_config: deepspeed_config_file: scripts/accelerate_configs/zero2_offload.json fsdp_config: {} machine_rank: 0 main_process_ip: null main_process_port: 29501 main_training_function: main num_machines: 1 num_processes: 8 gpu_ids: 0,1,2,3,4,5,6,7 use_cpu: false ================================================ FILE: scripts/accelerate_configs/deepspeed_zero3_config.yaml ================================================ compute_environment: LOCAL_MACHINE distributed_type: DEEPSPEED deepspeed_config: deepspeed_config_file: scripts/accelerate_configs/zero3.json fsdp_config: {} machine_rank: 0 main_process_ip: null main_process_port: 29501 main_training_function: main num_machines: 1 num_processes: 8 gpu_ids: 0,1,2,3,4,5,6,7 use_cpu: false ================================================ FILE: scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml ================================================ compute_environment: LOCAL_MACHINE distributed_type: DEEPSPEED deepspeed_config: deepspeed_config_file: scripts/accelerate_configs/zero3_offload.json fsdp_config: {} machine_rank: 0 main_process_ip: null main_process_port: 29501 main_training_function: main num_machines: 1 
num_processes: 8 gpu_ids: 0,1,2,3,4,5,6,7 use_cpu: false ================================================ FILE: scripts/accelerate_configs/default_config.yaml ================================================ compute_environment: LOCAL_MACHINE distributed_type: MULTI_GPU fsdp_config: {} machine_rank: 0 main_process_ip: null main_process_port: 29501 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 8 gpu_ids: 0,1,2,3,4,5,6,7 use_cpu: false ================================================ FILE: scripts/accelerate_configs/hostfile ================================================ 100.64.24.30 slots=8 100.64.24.6 slots=8 100.64.24.7 slots=8 100.64.24.8 slots=8 100.64.24.10 slots=8 100.64.24.11 slots=8 100.64.24.13 slots=8 100.64.24.14 slots=8 100.64.24.17 slots=8 100.64.24.19 slots=8 100.64.24.26 slots=8 100.64.24.27 slots=8 100.64.24.28 slots=8 100.64.24.29 slots=8 100.64.24.31 slots=8 100.64.24.32 slots=8 ================================================ FILE: scripts/accelerate_configs/zero2.json ================================================ { "fp16": { "enabled": false, "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "communication_data_type": "fp32", "gradient_clipping": 1.0, "train_micro_batch_size_per_gpu": "auto", "train_batch_size": "auto", "gradient_accumulation_steps": "auto", "zero_optimization": { "stage": 2, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": 5e8 } } ================================================ FILE: scripts/accelerate_configs/zero2_npu.json ================================================ { "fp16": { "enabled": false, "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "communication_data_type": "fp32", "gradient_clipping": 1.0, "train_micro_batch_size_per_gpu": "auto", 
"train_batch_size": "auto", "gradient_accumulation_steps": "auto", "zero_optimization": { "stage": 2, "overlap_comm": true, "allgather_bucket_size": 536870912, "contiguous_gradients": true, "reduce_bucket_size": 536870912 } } ================================================ FILE: scripts/accelerate_configs/zero2_offload.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "communication_data_type": "fp32", "gradient_clipping": 1.0, "train_micro_batch_size_per_gpu": "auto", "train_batch_size": "auto", "gradient_accumulation_steps": "auto", "zero_optimization": { "stage": 2, "offload_optimizer": { "device": "cpu" }, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": 5e8, "round_robin_gradients": true } } ================================================ FILE: scripts/accelerate_configs/zero3.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "communication_data_type": "fp32", "gradient_clipping": 1.0, "train_micro_batch_size_per_gpu": "auto", "train_batch_size": "auto", "gradient_accumulation_steps": "auto", "zero_optimization": { "stage": 3, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": 5e8, "stage3_prefetch_bucket_size": 5e8, "stage3_param_persistence_threshold": "auto", "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true } } ================================================ FILE: scripts/accelerate_configs/zero3_offload.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, 
"initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": true }, "offload_param": { "device": "cpu", "pin_memory": true }, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": 5e8, "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "steps_per_print": 1e5, "wall_clock_breakdown": false } ================================================ FILE: scripts/causalvae/eval.sh ================================================ EXP_NAME=wfvae-4dim SAMPLE_RATE=1 NUM_FRAMES=33 RESOLUTION=256 METRIC=lpips SUBSET_SIZE=0 ORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin RECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} python opensora/models/causalvideovae/eval/eval.py \ --batch_size 8 \ --real_video_dir ${ORIGIN_DIR} \ --generated_video_dir ${RECON_DIR} \ --device cuda:1 \ --sample_fps 1 \ --sample_rate ${SAMPLE_RATE} \ --num_frames ${NUM_FRAMES} \ --resolution ${RESOLUTION} \ --crop_size ${RESOLUTION} \ --subset_size ${SUBSET_SIZE} \ --metric ${METRIC} ================================================ FILE: scripts/causalvae/prepare_eval.sh ================================================ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 DATASET_DIR=test_video EXP_NAME=wfvae SAMPLE_RATE=1 NUM_FRAMES=33 RESOLUTION=256 CKPT=ckpt SUBSET_SIZE=0 accelerate launch \ --config_file scripts/accelerate_configs/default_config.yaml \ opensora/models/causalvideovae/sample/rec_video_vae.py \ --batch_size 1 \ --real_video_dir 
${DATASET_DIR} \ --generated_video_dir video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} \ --device cuda \ --sample_fps 24 \ --sample_rate ${SAMPLE_RATE} \ --num_frames ${NUM_FRAMES} \ --resolution ${RESOLUTION} \ --subset_size ${SUBSET_SIZE} \ --num_workers 8 \ --from_pretrained ${CKPT} \ --model_name WFVAE \ --output_origin \ --crop_size ${RESOLUTION} ================================================ FILE: scripts/causalvae/rec_image.sh ================================================ CUDA_VISIBLE_DEVICES=0 python examples/rec_image.py \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/storage/lcm/WF-VAE/results/latent8" \ --image_path /storage/dataset/image/anytext3m/ocr_data/Art/images/gt_5544.jpg \ --rec_path rec_.jpg \ --device cuda \ --short_size 512 ================================================ FILE: scripts/causalvae/rec_video.sh ================================================ CUDA_VISIBLE_DEVICES=1 python examples/rec_video.py \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/storage/lcm/WF-VAE/results/latent8" \ --video_path /storage/lcm/WF-VAE/testvideo/gm1190263332-337350271.mp4 \ --rec_path rec_tile_.mp4 \ --device cuda \ --sample_rate 1 \ --num_frames 65 \ --height 512 \ --width 512 \ --fps 30 \ --enable_tiling ================================================ FILE: scripts/causalvae/train.sh ================================================ export WANDB_PROJECT=WFVAE export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export GLOO_SOCKET_IFNAME=bond0 export NCCL_SOCKET_IFNAME=bond0 export NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1 export NCCL_IB_GID_INDEX=3 export NCCL_IB_TC=162 export NCCL_IB_TIMEOUT=22 export NCCL_PXN_DISABLE=0 export NCCL_IB_QPS_PER_CONNECTION=4 export NCCL_ALGO=Ring export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 EXP_NAME=TRAIN torchrun \ --nnodes=1 --nproc_per_node=8 \ --master_addr=localhost \ --master_port=12133 \ opensora/train/train_causalvae.py \ --exp_name ${EXP_NAME} \ --video_path 
/storage/dataset/vae_eval/OpenMMLab___Kinetics-400/raw/Kinetics-400/videos_train/ \ --eval_video_path /storage/dataset/vae_eval/OpenMMLab___Kinetics-400/raw/Kinetics-400/videos_val/ \ --model_name WFVAE \ --model_config scripts/causalvae/wfvae_4dim.json \ --resolution 256 \ --num_frames 25 \ --batch_size 1 \ --lr 0.00001 \ --epochs 4 \ --disc_start 0 \ --save_ckpt_step 5000 \ --eval_steps 1000 \ --eval_batch_size 1 \ --eval_num_frames 33 \ --eval_sample_rate 1 \ --eval_subset_size 500 \ --eval_lpips \ --ema \ --ema_decay 0.999 \ --perceptual_weight 1.0 \ --loss_type l1 \ --sample_rate 1 \ --disc_cls opensora.models.causalvideovae.model.losses.LPIPSWithDiscriminator3D \ --wavelet_loss \ --wavelet_weight 0.1 ================================================ FILE: scripts/causalvae/wfvae_4dim.json ================================================ { "_class_name": "WFVAEModel", "_diffusers_version": "0.30.2", "base_channels": 128, "connect_res_layer_num": 1, "decoder_energy_flow_hidden_size": 128, "decoder_num_resblocks": 2, "dropout": 0.0, "encoder_energy_flow_hidden_size": 128, "encoder_num_resblocks": 2, "l1_dowmsample_block": "Downsample", "l1_downsample_wavelet": "HaarWaveletTransform2D", "l1_upsample_block": "Upsample", "l1_upsample_wavelet": "InverseHaarWaveletTransform2D", "l2_dowmsample_block": "Spatial2xTime2x3DDownsample", "l2_downsample_wavelet": "HaarWaveletTransform3D", "l2_upsample_block": "Spatial2xTime2x3DUpsample", "l2_upsample_wavelet": "InverseHaarWaveletTransform3D", "latent_dim": 4, "norm_type": "layernorm", "t_interpolation": "trilinear", "use_attention": true } ================================================ FILE: scripts/slurm/placeholder ================================================ ================================================ FILE: scripts/text_condition/gpu/sample_inpaint_v1_3.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29513 \ -m 
opensora.sample.sample \ --model_type "inpaint" \ --model_path model_path \ --version v1_3 \ --num_frames 93 \ --height 352 \ --width 640 \ --max_hxw 236544 \ --crop_for_hw \ --cache_dir "../cache_dir" \ --text_encoder_name_1 "/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl" \ --text_prompt examples/cond_prompt.txt \ --conditional_pixel_values_path examples/cond_pix_path.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/storage/lcm/WF-VAE/results/latent8" \ --save_img_path "./save_path" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ --noise_strength 0.0 \ ================================================ FILE: scripts/text_condition/gpu/sample_t2v_v1_3.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29514 \ -m opensora.sample.sample \ --model_path /storage/ongoing/9.29/mmdit/Open-Sora-Plan/final_ft_any93x352x640_v1_3_bs512_lr1e-5_snr5.0_fps16_zsnr_nofix_16node/checkpoint-5500/model_ema \ --version v1_3 \ --num_frames 93 \ --height 352 \ --width 640 \ --cache_dir "../cache_dir" \ --text_encoder_name_1 "/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl" \ --text_prompt "examples/sora.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/storage/lcm/WF-VAE/results/latent8" \ --save_img_path "./train_1_3_nomotion_fps18" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" ================================================ FILE: scripts/text_condition/gpu/train_inpaint_v1_3.sh ================================================ export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1 export PDSH_RCMD_TYPE=ssh # NCCL 
setting export GLOO_SOCKET_IFNAME=bond0 export NCCL_SOCKET_IFNAME=bond0 export NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1 export NCCL_IB_GID_INDEX=3 export NCCL_IB_TC=162 export NCCL_IB_TIMEOUT=25 export NCCL_PXN_DISABLE=0 export NCCL_IB_QPS_PER_CONNECTION=4 export NCCL_ALGO=Ring export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 export NCCL_IB_RETRY_CNT=32 # export NCCL_ALGO=Tree accelerate launch \ --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ opensora/train/train_inpaint.py \ --model OpenSoraInpaint_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "../../cache_dir/" \ --dataset inpaint \ --data "scripts/train_data/video_data.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/storage/lcm/WF-VAE/results/latent8" \ --sample_rate 1 \ --num_frames 93 \ --max_hxw 236544 \ --min_hxw 102400 \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ --gradient_checkpointing \ --train_batch_size=1 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ --learning_rate=1e-5 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --mixed_precision="bf16" \ --report_to="wandb" \ --checkpointing_steps=1000 \ --allow_tf32 \ --model_max_length 512 \ --use_ema \ --ema_start_step 0 \ --cfg 0.1 \ --resume_from_checkpoint="latest" \ --speed_factor 1.0 \ --ema_decay 0.9999 \ --drop_short_ratio 0.0 \ --hw_stride 32 \ --sparse1d --sparse_n 4 \ --train_fps 18 \ --seed 1234 \ --trained_data_global_step 0 \ --group_data \ --use_decord \ --prediction_type "v_prediction" \ --output_dir="debug" \ --rescale_betas_zero_snr \ --mask_config scripts/train_configs/mask_config.yaml \ --add_noise_to_condition \ --default_text_ratio 0.5 # --pretrained "" ================================================ FILE: scripts/text_condition/gpu/train_t2v_v1_3.sh ================================================ export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1 export PDSH_RCMD_TYPE=ssh # NCCL 
setting export GLOO_SOCKET_IFNAME=bond0 export NCCL_SOCKET_IFNAME=bond0 export NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1 export NCCL_IB_GID_INDEX=3 export NCCL_IB_TC=162 export NCCL_IB_TIMEOUT=25 export NCCL_PXN_DISABLE=0 export NCCL_IB_QPS_PER_CONNECTION=4 export NCCL_ALGO=Ring export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 export NCCL_IB_RETRY_CNT=32 # export NCCL_ALGO=Tree accelerate launch \ --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "../../cache_dir/" \ --dataset t2v \ --data "scripts/train_data/merge_data.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/storage/lcm/WF-VAE/results/latent8" \ --sample_rate 1 \ --num_frames 1 \ --max_height 352 \ --max_width 640 \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ --gradient_checkpointing \ --train_batch_size=4 \ --dataloader_num_workers 16 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ --learning_rate=1e-5 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --mixed_precision="bf16" \ --report_to="wandb" \ --checkpointing_steps=500 \ --allow_tf32 \ --model_max_length 512 \ --use_ema \ --ema_start_step 0 \ --cfg 0.1 \ --resume_from_checkpoint="latest" \ --speed_factor 1.0 \ --ema_decay 0.9999 \ --drop_short_ratio 0.0 \ --pretrained "" \ --hw_stride 32 \ --sparse1d --sparse_n 4 \ --train_fps 16 \ --seed 1234 \ --trained_data_global_step 0 \ --group_data \ --use_decord \ --prediction_type "v_prediction" \ --snr_gamma 5.0 \ --force_resolution \ --rescale_betas_zero_snr \ --output_dir="debug" ================================================ FILE: scripts/text_condition/npu/sample_inpaint_v1_3.sh ================================================ export TASK_QUEUE_ENABLE=0 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29522 \ -m opensora.sample.sample \ --model_type "inpaint" \ --model_path 
model_path \ --version v1_3 \ --num_frames 93 \ --crop_for_hw \ --height 352 \ --width 640 \ --max_hxw 236544 \ --cache_dir "../cache_dir" \ --text_encoder_name_1 "/home/save_dir/pretrained/mt5-xxl" \ --text_prompt /home/image_data/gyy/mmdit/Open-Sora-Plan/validation_dir/prompt.txt \ --conditional_pixel_values_path /home/image_data/gyy/mmdit/Open-Sora-Plan/validation_dir/cond_imgs_path.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ --save_img_path "./test" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 50 \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 2514 \ --num_samples_per_prompt 1 \ --prediction_type "v_prediction" \ --rescale_betas_zero_snr \ --noise_strength 0.0 \ # --mask_type i2v \ # --enable_tiling ================================================ FILE: scripts/text_condition/npu/sample_t2v_v1_3.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29513 \ -m opensora.sample.sample \ --model_path model_path \ --version v1_3 \ --num_frames 93 \ --height 352 \ --width 640 \ --cache_dir "../cache_dir" \ --text_encoder_name_1 "/home/save_dir/pretrained/mt5-xxl" \ --text_prompt examples/sora_refine.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ --save_img_path "./test" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" ================================================ FILE: scripts/text_condition/npu/train_inpaint_v1_3.sh ================================================ export PROJECT=$PROJECT_NAME # export PROJECT='test' export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1 export TASK_QUEUE_ENABLE=0 export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE export 
MULTI_STREAM_MEMORY_REUSE=1 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True # export HCCL_ALGO="level0:NA;level1:H-D_R" # --machine_rank=${MACHINE_RANK} \ # --main_process_ip=${MAIN_PROCESS_IP_VALUE} \ # multi_node_example_by_deepspeed.yaml # deepspeed_zero2_config.yaml accelerate launch \ --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ opensora/train/train_inpaint.py \ --model OpenSoraInpaint_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "../../cache_dir/" \ --dataset inpaint \ --data "scripts/train_data/video_data.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ --vae_fp32 \ --sample_rate 1 \ --num_frames 93 \ --max_hxw 236544 \ --min_hxw 102400 \ --snr_gamma 5.0 \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ --gradient_checkpointing \ --train_batch_size=1 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ --learning_rate=1e-5 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --mixed_precision="bf16" \ --report_to="wandb" \ --checkpointing_steps=500 \ --allow_tf32 \ --model_max_length 512 \ --use_ema \ --ema_start_step 0 \ --cfg 0.1 \ --speed_factor 1.0 \ --ema_decay 0.9999 \ --drop_short_ratio 0.0 \ --hw_stride 32 \ --sparse1d --sparse_n=4 \ --train_fps 16 \ --seed 1234 \ --trained_data_global_step 0 \ --group_data \ --use_decord \ --prediction_type "v_prediction" \ --output_dir="/home/save_dir/runs/$PROJECT" \ --mask_config scripts/train_configs/mask_config.yaml \ --add_noise_to_condition \ --default_text_ratio 0.5 \ --resume_from_checkpoint="latest" # --pretrained "/home/save_dir/pretrained/93x640x640_144k_ema" # --force_resolution # --force_resolution \ # --max_height 352 \ # --max_width 640 \ ================================================ FILE: scripts/text_condition/npu/train_t2v_v1_3.sh ================================================ export PROJECT=$PROJECT_NAME # export 
PROJECT='test' export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1 export TASK_QUEUE_ENABLE=0 export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE export MULTI_STREAM_MEMORY_REUSE=1 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True # export HCCL_ALGO="level0:NA;level1:H-D_R" # --machine_rank=${MACHINE_RANK} \ # --main_process_ip=${MAIN_PROCESS_IP_VALUE} \ # multi_node_example_by_deepspeed.yaml # deepspeed_zero2_config.yaml accelerate launch \ --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "../../cache_dir/" \ --dataset t2v \ --data "scripts/train_data/video_data_debug_on_npu.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ --sample_rate 1 \ --num_frames 93 \ --max_height 352 \ --max_width 640 \ --force_resolution \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ --gradient_checkpointing \ --train_batch_size=1 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ --learning_rate=1e-5 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --mixed_precision="bf16" \ --report_to="wandb" \ --checkpointing_steps=500 \ --allow_tf32 \ --model_max_length 512 \ --use_ema \ --ema_start_step 0 \ --cfg 0.1 \ --resume_from_checkpoint="latest" \ --speed_factor 1.0 \ --ema_decay 0.9999 \ --drop_short_ratio 0.0 \ --pretrained "/home/save_dir/pretrained/93x640x640_144k_ema" \ --hw_stride 32 \ --sparse1d --sparse_n 4 \ --train_fps 16 \ --seed 1234 \ --trained_data_global_step 0 \ --group_data \ --use_decord \ --prediction_type "v_prediction" \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ --output_dir="debug" ================================================ FILE: scripts/train_configs/mask_config.yaml ================================================ # mask processor args min_clear_ratio: 0.0 max_clear_ratio: 1.0 # 
mask_type_ratio_dict_video mask_type_ratio_dict_video: t2iv: 1 i2v: 8 transition: 8 continuation: 2 clear: 0 random_temporal: 1 mask_type_ratio_dict_image: t2iv: 0 clear: 0 ================================================ FILE: scripts/train_data/merge_data.txt ================================================ /storage/dataset/recap_datacomp_1b_data/output,/storage/anno_pkl/img_nocn_res160_pkl/recap_64part_filter_aes_res160_pkl/part0_7036495.pkl