Repository: guoyww/AnimateDiff Branch: main Commit: e92bd5671ba6 Files: 39 Total size: 318.0 KB Directory structure: gitextract_t6ymel6r/ ├── .gitignore ├── LICENSE.txt ├── README.md ├── __assets__/ │ ├── animations/ │ │ └── compare/ │ │ └── ffmpeg │ └── docs/ │ ├── animatediff.md │ └── gallery.md ├── animatediff/ │ ├── data/ │ │ └── dataset.py │ ├── models/ │ │ ├── attention.py │ │ ├── motion_module.py │ │ ├── resnet.py │ │ ├── sparse_controlnet.py │ │ ├── unet.py │ │ └── unet_blocks.py │ ├── pipelines/ │ │ └── pipeline_animation.py │ └── utils/ │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ └── util.py ├── app.py ├── configs/ │ ├── inference/ │ │ ├── inference-v1.yaml │ │ ├── inference-v2.yaml │ │ ├── inference-v3.yaml │ │ └── sparsectrl/ │ │ ├── image_condition.yaml │ │ └── latent_condition.yaml │ ├── prompts/ │ │ ├── 1_animate/ │ │ │ ├── 1_1_animate_RealisticVision.yaml │ │ │ ├── 1_2_animate_FilmVelvia.yaml │ │ │ ├── 1_3_animate_ToonYou.yaml │ │ │ ├── 1_4_animate_MajicMix.yaml │ │ │ ├── 1_5_animate_RcnzCartoon.yaml │ │ │ ├── 1_6_animate_Lyriel.yaml │ │ │ └── 1_7_animate_Tusun.yaml │ │ ├── 2_motionlora/ │ │ │ └── 2_motionlora_RealisticVision.yaml │ │ └── 3_sparsectrl/ │ │ ├── 3_1_sparsectrl_i2v.yaml │ │ ├── 3_2_sparsectrl_rgb_RealisticVision.yaml │ │ └── 3_3_sparsectrl_sketch_RealisticVision.yaml │ └── training/ │ └── v1/ │ ├── image_finetune.yaml │ └── training.yaml ├── requirements.txt ├── scripts/ │ └── animate.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ wandb/ *debug* debugs/ outputs/ samples/ __pycache__/ ossutil_output/ .ossutil_checkpoint/ scripts/* !scripts/animate.py *.ipynb *.safetensors *.ckpt models/* !models/StableDiffusion/ models/StableDiffusion/* !models/StableDiffusion/*.txt !models/Motion_Module/ !models/Motion_Module/*.txt !models/DreamBooth_LoRA/ !models/DreamBooth_LoRA/*.txt !models/MotionLoRA/ !models/MotionLoRA/*.txt ================================================ FILE: LICENSE.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # AnimateDiff This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725) [ICLR2024 Spotlight]. It is a plug-and-play module turning most community text-to-image models into animation generators, without the need of additional training. **[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)**
[Yuwei Guo](https://guoyww.github.io/), [Ceyuan Yang✝](https://ceyuan.me/), [Anyi Rao](https://anyirao.com/), [Zhengyang Liang](https://maxleung99.github.io/), [Yaohui Wang](https://wyhsirius.github.io/), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), [Maneesh Agrawala](https://graphics.stanford.edu/~maneesh/), [Dahua Lin](http://dahua.site), [Bo Dai](https://daibo.info) (✝Corresponding Author) [![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://arxiv.org/abs/2307.04725) [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://animatediff.github.io/) [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/guoyww/AnimateDiff) ***Note:*** The `main` branch is for [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5); for [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), please refer `sdxl-beta` branch. ## Quick Demos More results can be found in the [Gallery](__assets__/docs/gallery.md). Some of them are contributed by the community.

Model:ToonYou

Model:Realistic Vision V2.0

## Quick Start ***Note:*** AnimateDiff is also offically supported by Diffusers. Visit [AnimateDiff Diffusers Tutorial](https://huggingface.co/docs/diffusers/api/pipelines/animatediff) for more details. *Following instructions is for working with this repository*. ***Note:*** For all scripts, checkpoint downloading will be *automatically* handled, so the script running may take longer time when first executed. ### 1. Setup repository and environment ``` git clone https://github.com/guoyww/AnimateDiff.git cd AnimateDiff pip install -r requirements.txt ``` ### 2. Launch the sampling script! The generated samples can be found in `samples/` folder. #### 2.1 Generate animations with comunity models ``` python -m scripts.animate --config configs/prompts/1_animate/1_1_animate_RealisticVision.yaml python -m scripts.animate --config configs/prompts/1_animate/1_2_animate_FilmVelvia.yaml python -m scripts.animate --config configs/prompts/1_animate/1_3_animate_ToonYou.yaml python -m scripts.animate --config configs/prompts/1_animate/1_4_animate_MajicMix.yaml python -m scripts.animate --config configs/prompts/1_animate/1_5_animate_RcnzCartoon.yaml python -m scripts.animate --config configs/prompts/1_animate/1_6_animate_Lyriel.yaml python -m scripts.animate --config configs/prompts/1_animate/1_7_animate_Tusun.yaml ``` #### 2.2 Generate animation with MotionLoRA control ``` python -m scripts.animate --config configs/prompts/2_motionlora/2_motionlora_RealisticVision.yaml ``` #### 2.3 More control with SparseCtrl RGB and sketch ``` python -m scripts.animate --config configs/prompts/3_sparsectrl/3_1_sparsectrl_i2v.yaml python -m scripts.animate --config configs/prompts/3_sparsectrl/3_2_sparsectrl_rgb_RealisticVision.yaml python -m scripts.animate --config configs/prompts/3_sparsectrl/3_3_sparsectrl_sketch_RealisticVision.yaml ``` #### 2.4 Gradio app We created a Gradio demo to make AnimateDiff easier to use. By default, the demo will run at `localhost:7860`. ``` python -u app.py ``` ## Technical Explanation
Technical Explanation ### AnimateDiff **AnimateDiff aims to learn transferable motion priors that can be applied to other variants of Stable Diffusion family.** To this end, we design the following training pipeline consisting of three stages. - In **1. Alleviate Negative Effects** stage, we train the **domain adapter**, e.g., `v3_sd15_adapter.ckpt`, to fit defective visual aritfacts (e.g., watermarks) in the training dataset. This can also benefit the distangled learning of motion and spatial appearance. By default, the adapter can be removed at inference. It can also be integrated into the model and its effects can be adjusted by a lora scaler. - In **2. Learn Motion Priors** stage, we train the **motion module**, e.g., `v3_sd15_mm.ckpt`, to learn the real-world motion patterns from videos. - In **3. (optional) Adapt to New Patterns** stage, we train **MotionLoRA**, e.g., `v2_lora_ZoomIn.ckpt`, to efficiently adapt motion module for specific motion patterns (camera zooming, rolling, etc.). ### SparseCtrl **SparseCtrl aims to add more control to text-to-video models by adopting some sparse inputs (e.g., few RGB images or sketch inputs).** Its technicall details can be found in the following paper: **[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933)** [Yuwei Guo](https://guoyww.github.io/), [Ceyuan Yang✝](https://ceyuan.me/), [Anyi Rao](https://anyirao.com/), [Maneesh Agrawala](https://graphics.stanford.edu/~maneesh/), [Dahua Lin](http://dahua.site), [Bo Dai](https://daibo.info) (✝Corresponding Author) [![arXiv](https://img.shields.io/badge/arXiv-2311.16933-b31b1b.svg)](https://arxiv.org/abs/2311.16933) [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://guoyww.github.io/projects/SparseCtrl/)
## Model Versions
Model Versions ### AnimateDiff v3 and SparseCtrl (2023.12) In this version, we use **Domain Adapter LoRA** for image model finetuning, which provides more flexiblity at inference. We also implement two (RGB image/scribble) [SparseCtrl](https://arxiv.org/abs/2311.16933) encoders, which can take abitary number of condition maps to control the animation contents.
AnimateDiff v3 Model Zoo | Name | HuggingFace | Type | Storage | Description | | - | - | - | - | - | | `v3_adapter_sd_v15.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt) | Domain Adapter | 97.4 MB | | | `v3_sd15_mm.ckpt.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt) | Motion Module | 1.56 GB | | | `v3_sd15_sparsectrl_scribble.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_sparsectrl_scribble.ckpt) | SparseCtrl Encoder | 1.86 GB | scribble condition | | `v3_sd15_sparsectrl_rgb.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_sparsectrl_rgb.ckpt) | SparseCtrl Encoder | 1.85 GB | RGB image condition |
#### Limitations 1. Small fickering is noticable; 2. To stay compatible with comunity models, there is no specific optimizations for general T2V, leading to limited visual quality under this setting; 3. **(Style Alignment) For usage such as image animation/interpolation, it's recommanded to use images generated by the same community model.** #### Demos
Input (by RealisticVision) Animation Input Animation
Input Scribble Output Input Scribbles Output
### AnimateDiff SDXL-Beta (2023.11) Release the Motion Module (beta version) on SDXL, available at [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link ) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt ) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules). High resolution videos (i.e., 1024x1024x16 frames with various aspect ratios) could be produced **with/without** personalized models. Inference usually requires ~13GB VRAM and tuned hyperparameters (e.g., sampling steps), depending on the chosen personalized models. Checkout to the branch [sdxl](https://github.com/guoyww/AnimateDiff/tree/sdxl) for more details of the inference.
AnimateDiff SDXL-Beta Model Zoo | Name | HuggingFace | Type | Storage Space | | - | - | - | - | | `mm_sdxl_v10_beta.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt) | Motion Module | 950 MB |
#### Demos
Original SDXL Community SDXL Community SDXL
### AnimateDiff v2 (2023.09) In this version, the motion module `mm_sd_v15_v2.ckpt` ([Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules)) is trained upon larger resolution and batch size. We found that the scale-up training significantly helps improve the motion quality and diversity. We also support **MotionLoRA** of eight basic camera movements. MotionLoRA checkpoints take up only **77 MB storage per model**, and are available at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules).
AnimateDiff v2 Model Zoo | Name | HuggingFace | Type | Parameter | Storage | | - | - | - | - | - | | `mm_sd_v15_v2.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt) | Motion Module | 453 M | 1.7 GB | | `v2_lora_ZoomIn.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_ZoomIn.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_ZoomOut.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_ZoomOut.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_PanLeft.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_PanLeft.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_PanRight.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_PanRight.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_TiltUp.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_TiltUp.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_TiltDown.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_TiltDown.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_RollingClockwise.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_RollingClockwise.ckpt) | MotionLoRA | 19 M | 74 MB | | `v2_lora_RollingAnticlockwise.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_RollingAnticlockwise.ckpt) | MotionLoRA | 19 M | 74 MB |
#### Demos (MotionLoRA)
Zoom In Zoom Out Zoom Pan Left Zoom Pan Right
Tilt Up Tilt Down Rolling Anti-Clockwise Rolling Clockwise
#### Demos (Improved Motions) Here's a comparison between `mm_sd_v15.ckpt` (left) and improved `mm_sd_v15_v2.ckpt` (right).
### AnimateDiff v1 (2023.07) The first version of AnimateDiff!
AnimateDiff v1 Model Zoo | Name | HuggingFace | Parameter | Storage Space | | - | - | - | - | | mm_sd_v14.ckpt | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v14.ckpt) | 417 M | 1.6 GB | | mm_sd_v15.ckpt | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15.ckpt) | 417 M | 1.6 GB |
## Training Please check [Steps for Training](__assets__/docs/animatediff.md) for details. ## Related Resources AnimateDiff for Stable Diffusion WebUI: [sd-webui-animatediff](https://github.com/continue-revolution/sd-webui-animatediff) (by [@continue-revolution](https://github.com/continue-revolution)) AnimateDiff for ComfyUI: [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) (by [@Kosinkadink](https://github.com/Kosinkadink)) Google Colab: [Colab](https://colab.research.google.com/github/camenduru/AnimateDiff-colab/blob/main/AnimateDiff_colab.ipynb) (by [@camenduru](https://github.com/camenduru)) ## Disclaimer This project is released for academic use. We disclaim responsibility for user-generated content. Also, please be advised that our only official website are https://github.com/guoyww/AnimateDiff and https://animatediff.github.io, and all the other websites are NOT associated with us at AnimateDiff. ## Contact Us Yuwei Guo: [guoyw@ie.cuhk.edu.hk](mailto:guoyw@ie.cuhk.edu.hk) Ceyuan Yang: [limbo0066@gmail.com](mailto:limbo0066@gmail.com) Bo Dai: [doubledaibo@gmail.com](mailto:doubledaibo@gmail.com) ## BibTeX ``` @article{guo2023animatediff, title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning}, author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Liang, Zhengyang and Wang, Yaohui and Qiao, Yu and Agrawala, Maneesh and Lin, Dahua and Dai, Bo}, journal={International Conference on Learning Representations}, year={2024} } @article{guo2023sparsectrl, title={SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models}, author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Agrawala, Maneesh and Lin, Dahua and Dai, Bo}, journal={arXiv preprint arXiv:2311.16933}, year={2023} } ``` ## Acknowledgements Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video). ================================================ FILE: __assets__/animations/compare/ffmpeg ================================================ ================================================ FILE: __assets__/docs/animatediff.md ================================================ ## Steps for Training ### Dataset Before training, download the videos files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local mechine. Note that our examplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`. ### Configuration After dataset preparations, update the below data paths in the config `.yaml` files in `configs/training/` folder: ``` train_data: csv_path: [Replace with .csv Annotation File Path] video_folder: [Replace with Video Folder Path] sample_size: 256 ``` Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files. ### Training To finetune the unet's image layers ``` torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/v1/image_finetune.yaml ``` To train motion modules ``` torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/v1/training.yaml ``` ================================================ FILE: __assets__/docs/gallery.md ================================================ # Gallery Here we demonstrate several best results we found in our experiments.

Model:ToonYou

Model:Counterfeit V3.0

Model:Realistic Vision V2.0

Model: majicMIX Realistic

Model:RCNZ Cartoon

Model:FilmVelvia

#### Community Cases Here are some samples contributed by the community artists. Create a Pull Request if you would like to show your results here😚.

Character Model:Yoimiya (with an initial reference image, see WIP fork for the extended implementation.)

Character Model:Paimon; Pose Model:Hold Sign

================================================ FILE: animatediff/data/dataset.py ================================================ import os, io, csv, math, random import numpy as np from einops import rearrange from decord import VideoReader import torch import torchvision.transforms as transforms from torch.utils.data.dataset import Dataset from animatediff.utils.util import zero_rank_print class WebVid10M(Dataset): def __init__( self, csv_path, video_folder, sample_size=256, sample_stride=4, sample_n_frames=16, is_image=False, ): zero_rank_print(f"loading annotations from {csv_path} ...") with open(csv_path, 'r') as csvfile: self.dataset = list(csv.DictReader(csvfile)) self.length = len(self.dataset) zero_rank_print(f"data scale: {self.length}") self.video_folder = video_folder self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.is_image = is_image sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) self.pixel_transforms = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.Resize(sample_size[0]), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) def get_batch(self, idx): video_dict = self.dataset[idx] videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") video_reader = VideoReader(video_dir) video_length = len(video_reader) if not self.is_image: clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) start_idx = random.randint(0, video_length - clip_length) batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) else: batch_index = [random.randint(0, video_length - 1)] pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() pixel_values = pixel_values / 255. del video_reader if self.is_image: pixel_values = pixel_values[0] return pixel_values, name def __len__(self): return self.length def __getitem__(self, idx): while True: try: pixel_values, name = self.get_batch(idx) break except Exception as e: idx = random.randint(0, self.length-1) pixel_values = self.pixel_transforms(pixel_values) sample = dict(pixel_values=pixel_values, text=name) return sample if __name__ == "__main__": from animatediff.utils.util import save_videos_grid dataset = WebVid10M( csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv", video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val", sample_size=256, sample_stride=4, sample_n_frames=16, is_image=True, ) import pdb pdb.set_trace() dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,) for idx, batch in enumerate(dataloader): print(batch["pixel_values"].shape, len(batch["text"])) # for i in range(batch["pixel_values"].shape[0]): # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True) ================================================ FILE: animatediff/models/attention.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm from einops import rearrange, repeat import pdb @dataclass class Transformer3DModelOutput(BaseOutput): sample: torch.FloatTensor if is_xformers_available(): import xformers import xformers.ops else: xformers = None class Transformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # Define input layers self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) # Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) for d in range(num_layers) ] ) # 4. Define output layers if use_linear_projection: self.proj_out = nn.Linear(in_channels, inner_dim) else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # Input assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." video_length = hidden_states.shape[2] hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) # Blocks for block in self.transformer_blocks: hidden_states = block( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, video_length=video_length ) # Output if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) output = hidden_states + residual output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) if not return_dict: return (output,) return Transformer3DModelOutput(sample=output) class BasicTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, unet_use_cross_frame_attention = None, unet_use_temporal_attention = None, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm = num_embeds_ada_norm is not None self.unet_use_cross_frame_attention = unet_use_cross_frame_attention self.unet_use_temporal_attention = unet_use_temporal_attention # SC-Attn assert unet_use_cross_frame_attention is not None if unet_use_cross_frame_attention: self.attn1 = SparseCausalAttention2D( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) else: self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) # Cross-Attn if cross_attention_dim is not None: self.attn2 = CrossAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) else: self.attn2 = None if cross_attention_dim is not None: self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) else: self.norm2 = None # Feed-forward self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.norm3 = nn.LayerNorm(dim) # Temp-Attn assert unet_use_temporal_attention is not None if unet_use_temporal_attention: self.attn_temp = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) nn.init.zeros_(self.attn_temp.to_out[0].weight.data) self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): print("Here is how to install it") raise ModuleNotFoundError( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers", name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" " available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers if self.attn2 is not None: self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): # SparseCausal-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) # if self.only_cross_attention: # hidden_states = ( # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states # ) # else: # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states # pdb.set_trace() if self.unet_use_cross_frame_attention: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states else: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states if self.attn2 is not None: # Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) hidden_states = ( self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) + hidden_states ) # Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states # Temporal-Attention if self.unet_use_temporal_attention: d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) norm_hidden_states = ( self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) ) hidden_states = self.attn_temp(norm_hidden_states) + hidden_states hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states ================================================ FILE: animatediff/models/motion_module.py ================================================ from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch import numpy as np import torch.nn.functional as F from torch import nn import torchvision from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available from diffusers.models.attention import CrossAttention, FeedForward from einops import rearrange, repeat import math def zero_module(module): # Zero out the parameters of a module and return it. for p in module.parameters(): p.detach().zero_() return module @dataclass class TemporalTransformer3DModelOutput(BaseOutput): sample: torch.FloatTensor if is_xformers_available(): import xformers import xformers.ops else: xformers = None def get_motion_module( in_channels, motion_module_type: str, motion_module_kwargs: dict ): if motion_module_type == "Vanilla": return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) else: raise ValueError class VanillaTemporalModule(nn.Module): def __init__( self, in_channels, num_attention_heads = 8, num_transformer_block = 2, attention_block_types =( "Temporal_Self", "Temporal_Self" ), cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, temporal_attention_dim_div = 1, zero_initialize = True, ): super().__init__() self.temporal_transformer = TemporalTransformer3DModel( in_channels=in_channels, num_attention_heads=num_attention_heads, attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, num_layers=num_transformer_block, attention_block_types=attention_block_types, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) if zero_initialize: self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): hidden_states = input_tensor hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) output = hidden_states return output class TemporalTransformer3DModel(nn.Module): def __init__( self, in_channels, num_attention_heads, attention_head_dim, num_layers, attention_block_types = ( "Temporal_Self", "Temporal_Self", ), dropout = 0.0, norm_num_groups = 32, cross_attention_dim = 768, activation_fn = "geglu", attention_bias = False, upcast_attention = False, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( [ TemporalTransformerBlock( dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, attention_block_types=attention_block_types, dropout=dropout, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, upcast_attention=upcast_attention, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) for d in range(num_layers) ] ) self.proj_out = nn.Linear(inner_dim, in_channels) def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." video_length = hidden_states.shape[2] hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) # Transformer Blocks for block in self.transformer_blocks: hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) # output hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) return output class TemporalTransformerBlock(nn.Module): def __init__( self, dim, num_attention_heads, attention_head_dim, attention_block_types = ( "Temporal_Self", "Temporal_Self", ), dropout = 0.0, norm_num_groups = 32, cross_attention_dim = 768, activation_fn = "geglu", attention_bias = False, upcast_attention = False, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, ): super().__init__() attention_blocks = [] norms = [] for block_name in attention_block_types: attention_blocks.append( VersatileAttention( attention_mode=block_name.split("_")[0], cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) norms.append(nn.LayerNorm(dim)) self.attention_blocks = nn.ModuleList(attention_blocks) self.norms = nn.ModuleList(norms) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.ff_norm = nn.LayerNorm(dim) def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): for attention_block, norm in zip(self.attention_blocks, self.norms): norm_hidden_states = norm(hidden_states) hidden_states = attention_block( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, video_length=video_length, ) + hidden_states hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states output = hidden_states return output class PositionalEncoding(nn.Module): def __init__( self, d_model, dropout = 0., max_len = 24 ): super().__init__() self.dropout = nn.Dropout(p=dropout) position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(1, max_len, d_model) pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe, persistent=False) def forward(self, x): x = x + self.pe[:, :x.size(1)] return self.dropout(x) class VersatileAttention(CrossAttention): def __init__( self, attention_mode = None, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 32, *args, **kwargs ): super().__init__(*args, **kwargs) assert attention_mode == "Temporal" self.attention_mode = attention_mode self.is_cross_attention = kwargs["cross_attention_dim"] is not None self.pos_encoder = PositionalEncoding( kwargs["query_dim"], dropout=0., max_len=temporal_position_encoding_max_len ) if (temporal_position_encoding and attention_mode == "Temporal") else None def extra_repr(self): return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): batch_size, sequence_length, _ = hidden_states.shape if self.attention_mode == "Temporal": d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) if self.pos_encoder is not None: hidden_states = self.pos_encoder(hidden_states) encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states else: raise NotImplementedError encoder_hidden_states = encoder_hidden_states if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) dim = query.shape[-1] query = self.reshape_heads_to_batch_dim(query) if self.added_kv_proj_dim is not None: raise NotImplementedError encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) if attention_mask is not None: if attention_mask.shape[-1] != query.shape[1]: target_length = query.shape[1] attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) # attention, what we cannot get enough of if self._use_memory_efficient_attention_xformers: hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) # Some versions of xformers return output in fp32, cast it back to the dtype of the input hidden_states = hidden_states.to(query.dtype) else: if self._slice_size is None or query.shape[0] // self._slice_size == 1: hidden_states = self._attention(query, key, value, attention_mask) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) if self.attention_mode == "Temporal": hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states ================================================ FILE: animatediff/models/resnet.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange class InflatedConv3d(nn.Conv2d): def forward(self, x): video_length = x.shape[2] x = rearrange(x, "b c f h w -> (b f) c h w") x = super().forward(x) x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) return x class InflatedGroupNorm(nn.GroupNorm): def forward(self, x): video_length = x.shape[2] x = rearrange(x, "b c f h w -> (b f) c h w") x = super().forward(x) x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) return x class Upsample3D(nn.Module): def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name conv = None if use_conv_transpose: raise NotImplementedError elif use_conv: self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) def forward(self, hidden_states, output_size=None): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: raise NotImplementedError # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 dtype = hidden_states.dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: hidden_states = hidden_states.contiguous() # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") # If the input is bfloat16, we cast back to bfloat16 if dtype == torch.bfloat16: hidden_states = hidden_states.to(dtype) # if self.use_conv: # if self.name == "conv": # hidden_states = self.conv(hidden_states) # else: # hidden_states = self.Conv2d_0(hidden_states) hidden_states = self.conv(hidden_states) return hidden_states class Downsample3D(nn.Module): def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.padding = padding stride = 2 self.name = name if use_conv: self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: raise NotImplementedError def forward(self, hidden_states): assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: raise NotImplementedError assert hidden_states.shape[1] == self.channels hidden_states = self.conv(hidden_states) return hidden_states class ResnetBlock3D(nn.Module): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, groups_out=None, pre_norm=True, eps=1e-6, non_linearity="swish", time_embedding_norm="default", output_scale_factor=1.0, use_in_shortcut=None, use_inflated_groupnorm=False, ): super().__init__() self.pre_norm = pre_norm self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.time_embedding_norm = time_embedding_norm self.output_scale_factor = output_scale_factor if groups_out is None: groups_out = groups assert use_inflated_groupnorm != None if use_inflated_groupnorm: self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": time_emb_proj_out_channels = out_channels elif self.time_embedding_norm == "scale_shift": time_emb_proj_out_channels = out_channels * 2 else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) else: self.time_emb_proj = None if use_inflated_groupnorm: self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) else: self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) elif non_linearity == "mish": self.nonlinearity = Mish() elif non_linearity == "silu": self.nonlinearity = nn.SiLU() self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, input_tensor, temb): hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor class Mish(torch.nn.Module): def forward(self, hidden_states): return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) ================================================ FILE: animatediff/models/sparse_controlnet.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Changes were made to this source code by Yuwei Guo. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.modeling_utils import ModelMixin from .unet_blocks import ( CrossAttnDownBlock3D, DownBlock3D, UNetMidBlock3DCrossAttn, get_down_block, ) from einops import repeat, rearrange from .resnet import InflatedConv3d from diffusers.models.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class SparseControlNetOutput(BaseOutput): down_block_res_samples: Tuple[torch.Tensor] mid_block_res_sample: torch.Tensor class SparseControlNetConditioningEmbedding(nn.Module): def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int] = (16, 32, 96, 256), ): super().__init__() self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class SparseControlNetModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, use_motion_module = True, motion_module_resolutions = ( 1,2,4,8 ), motion_module_mid_block = False, motion_module_type = "Vanilla", motion_module_kwargs = { "num_attention_heads": 8, "num_transformer_block": 1, "attention_block_types": ["Temporal_Self"], "temporal_position_encoding": True, "temporal_position_encoding_max_len": 32, "temporal_attention_dim_div": 1, "causal_temporal_attention": False, }, concate_conditioning_mask: bool = True, use_simplified_condition_embedding: bool = False, set_noisy_sample_input_to_zero: bool = False, ): super().__init__() # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) # input self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = InflatedConv3d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) if concate_conditioning_mask: conditioning_channels = conditioning_channels + 1 self.concate_conditioning_mask = concate_conditioning_mask # control net conditioning embedding if use_simplified_condition_embedding: self.controlnet_cond_embedding = zero_module( InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) ) else: self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) self.use_simplified_condition_embedding = use_simplified_condition_embedding # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=True, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) # mid mid_block_channel = block_out_channels[-1] controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock3DCrossAttn( in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=num_attention_heads[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, use_inflated_groupnorm=True, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) @classmethod def from_unet( cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, controlnet_additional_kwargs: dict = {}, ): controlnet = cls( in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, down_block_types=unet.config.down_block_types, only_cross_attention=unet.config.only_cross_attention, block_out_channels=unet.config.block_out_channels, layers_per_block=unet.config.layers_per_block, downsample_padding=unet.config.downsample_padding, mid_block_scale_factor=unet.config.mid_block_scale_factor, act_fn=unet.config.act_fn, norm_num_groups=unet.config.norm_num_groups, norm_eps=unet.config.norm_eps, cross_attention_dim=unet.config.cross_attention_dim, attention_head_dim=unet.config.attention_head_dim, num_attention_heads=unet.config.num_attention_heads, use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, upcast_attention=unet.config.upcast_attention, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, conditioning_embedding_out_channels=conditioning_embedding_out_channels, **controlnet_additional_kwargs, ) if load_weights_from_unet: m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False) assert len(u) == 0 if controlnet.class_embedding: m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False) assert len(u) == 0 return controlnet @staticmethod def image_layer_filter(state_dict): new_state_dict = {} for name, param in state_dict.items(): if "motion_modules." in name or "lora" in name: continue new_state_dict[name] = param return new_state_dict # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.FloatTensor, conditioning_mask: Optional[torch.FloatTensor] = None, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, ) -> Union[SparseControlNetOutput, Tuple]: # set input noise to zero if self.set_noisy_sample_input_to_zero: sample = torch.zeros_like(sample).to(sample.device) # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # 2. pre-process sample = self.conv_in(sample) if self.concate_conditioning_mask: controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, # cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, # cross_attention_kwargs=cross_attention_kwargs, ) # 5. controlnet blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: return (down_block_res_samples, mid_block_res_sample) return SparseControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module ================================================ FILE: animatediff/models/unet.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py from dataclasses import dataclass from typing import List, Optional, Tuple, Union import os import json import pdb import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from .unet_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, DownBlock3D, UNetMidBlock3DCrossAttn, UpBlock3D, get_down_block, get_up_block, ) from .resnet import InflatedConv3d, InflatedGroupNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor class UNet3DConditionModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), mid_block_type: str = "UNetMidBlock3DCrossAttn", up_block_types: Tuple[str] = ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", use_inflated_groupnorm=False, # Additional use_motion_module = False, motion_module_resolutions = ( 1,2,4,8 ), motion_module_mid_block = False, motion_module_decoder_only = False, motion_module_type = None, motion_module_kwargs = {}, unet_use_cross_frame_attention = False, unet_use_temporal_attention = False, ): super().__init__() self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 # input self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) # time self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock3DCrossAttn": self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the videos self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): res = 2 ** (3 - i) is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if use_inflated_groupnorm: self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) else: self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module will split the input tensor in slices, to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_slicable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_slicable_dims(module) num_slicable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_slicable_layers * [1] slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, # support controlnet down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # pre-process sample = self.conv_in(sample) # down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) down_block_res_samples += res_samples # support controlnet down_block_res_samples = list(down_block_res_samples) if down_block_additional_residuals is not None: for i, down_block_additional_residual in enumerate(down_block_additional_residuals): if down_block_additional_residual.dim() == 4: # boardcast down_block_additional_residual = down_block_additional_residual.unsqueeze(2) down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual # mid sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) # support controlnet if mid_block_additional_residual is not None: if mid_block_additional_residual.dim() == 4: # boardcast mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) sample = sample + mid_block_additional_residual # up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, ) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) @classmethod def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs): from diffusers import __version__ from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available from diffusers.modeling_utils import load_state_dict print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...") cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } model_file = None if is_safetensors_available(): try: model_file = cls._get_model_file( pretrained_model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) except: pass if model_file is None: model_file = cls._get_model_file( pretrained_model_name_or_path, weights_name=WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) config, unused_kwargs = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, device_map=device_map, **kwargs, ) config["_class_name"] = cls.__name__ config["down_block_types"] = [ "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D" ] config["up_block_types"] = [ "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ] model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs) state_dict = load_state_dict(model_file) m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()] print(f"### Motion Module Parameters: {sum(params) / 1e6} M") return model ================================================ FILE: animatediff/models/unet_blocks.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py import torch from torch import nn from .attention import Transformer3DModel from .resnet import Downsample3D, ResnetBlock3D, Upsample3D from .motion_module import get_motion_module import pdb def get_down_block( down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock3D": return DownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") return CrossAttnDownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") return CrossAttnUpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) raise ValueError(f"{up_block_type} does not exist.") class UNetMidBlock3DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ] attentions = [] motion_modules = [] for _ in range(num_layers): if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=in_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states hidden_states = resnet(hidden_states, temb) return hidden_states class CrossAttnDownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): output_states = () for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, )[0] if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class DownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] motion_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () for resnet, motion_module in zip(self.resnets, self.motion_modules): if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class CrossAttnUpBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None, attention_mask=None, ): for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, )[0] if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class UpBlock3D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] motion_modules = [] for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): for resnet, motion_module in zip(self.resnets, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states ================================================ FILE: animatediff/pipelines/pipeline_animation.py ================================================ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py import inspect from typing import Callable, List, Optional, Union from dataclasses import dataclass import numpy as np import torch from tqdm import tqdm from diffusers.utils import is_accelerate_available from packaging import version from transformers import CLIPTextModel, CLIPTokenizer from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from diffusers.utils import deprecate, logging, BaseOutput from einops import rearrange from ..models.unet import UNet3DConditionModel from ..models.sparse_controlnet import SparseControlNetModel import pdb logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class AnimationPipelineOutput(BaseOutput): videos: Union[torch.Tensor, np.ndarray] class AnimationPipeline(DiffusionPipeline): _optional_components = [] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], controlnet: Union[SparseControlNetModel, None] = None, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, controlnet=controlnet, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def enable_vae_slicing(self): self.vae.enable_slicing() def disable_vae_slicing(self): self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property def _execution_device(self): if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None text_embeddings = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None uncond_embeddings = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings def decode_latents(self, latents): video_length = latents.shape[2] latents = 1 / 0.18215 * latents latents = rearrange(latents, "b c f h w -> (b f) c h w") # video = self.vae.decode(latents).sample video = [] for frame_idx in tqdm(range(latents.shape[0])): video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) video = torch.cat(video) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video = (video / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 video = video.cpu().float().numpy() return video def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, prompt, height, width, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: rand_device = "cpu" if device.type == "mps" else device if isinstance(generator, list): shape = shape # shape = (1,) + shape[1:] latents = [ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) ] latents = torch.cat(latents, dim=0).to(device) else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], video_length: Optional[int], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "tensor", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, # support controlnet controlnet_images: torch.FloatTensor = None, controlnet_image_index: list = [0], controlnet_conditioning_scale: Union[float, List[float]] = 1.0, **kwargs, ): # Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # Define call parameters # batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if latents is not None: batch_size = latents.shape[0] if isinstance(prompt, list): batch_size = len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # Encode input prompt prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size if negative_prompt is not None: negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size text_embeddings = self._encode_prompt( prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt ) # Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # Prepare latent variables num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, video_length, height, width, text_embeddings.dtype, device, generator, latents, ) latents_dtype = latents.dtype # Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) down_block_additional_residuals = mid_block_additional_residual = None if (getattr(self, "controlnet", None) != None) and (controlnet_images != None): assert controlnet_images.dim() == 5 controlnet_noisy_latents = latent_model_input controlnet_prompt_embeds = text_embeddings controlnet_images = controlnet_images.to(latents.device) controlnet_cond_shape = list(controlnet_images.shape) controlnet_cond_shape[2] = video_length controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) controlnet_conditioning_mask_shape = list(controlnet_cond.shape) controlnet_conditioning_mask_shape[1] = 1 controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) assert controlnet_images.shape[2] >= len(controlnet_image_index) controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 down_block_additional_residuals, mid_block_additional_residual = self.controlnet( controlnet_noisy_latents, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=controlnet_cond, conditioning_mask=controlnet_conditioning_mask, conditioning_scale=controlnet_conditioning_scale, guess_mode=False, return_dict=False, ) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=text_embeddings, down_block_additional_residuals = down_block_additional_residuals, mid_block_additional_residual = mid_block_additional_residual, ).sample.to(dtype=latents_dtype) # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # Post-processing video = self.decode_latents(latents) # Convert to tensor if output_type == "tensor": video = torch.from_numpy(video) if not return_dict: return video return AnimationPipelineOutput(videos=video) ================================================ FILE: animatediff/utils/convert_from_ckpt.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Conversion script for the Stable Diffusion checkpoints.""" import re from io import BytesIO from typing import Optional import requests import torch from transformers import ( AutoFeatureExtractor, BertTokenizerFast, CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionConfig, CLIPVisionModelWithProjection, ) from diffusers.models import ( AutoencoderKL, PriorTransformer, UNet2DConditionModel, ) from diffusers.schedulers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UnCLIPScheduler, ) from diffusers.utils.import_utils import BACKENDS_MAPPING def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. """ if n_shave_prefix_segments >= 0: return ".".join(path.split(".")[n_shave_prefix_segments:]) else: return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item.replace("in_layers.0", "norm1") new_item = new_item.replace("in_layers.2", "conv1") new_item = new_item.replace("out_layers.0", "norm2") new_item = new_item.replace("out_layers.3", "conv2") new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item # new_item = new_item.replace('norm.weight', 'group_norm.weight') # new_item = new_item.replace('norm.bias', 'group_norm.bias') # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("q.weight", "query.weight") new_item = new_item.replace("q.bias", "query.bias") new_item = new_item.replace("k.weight", "key.weight") new_item = new_item.replace("k.bias", "key.bias") new_item = new_item.replace("v.weight", "value.weight") new_item = new_item.replace("v.bias", "value.bias") new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.bias", "proj_attn.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def assign_to_checkpoint( paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new checkpoint. """ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. if attention_paths_to_split is not None: for path, path_map in attention_paths_to_split.items(): old_tensor = old_checkpoint[path] channels = old_tensor.shape[0] // 3 target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) checkpoint[path_map["key"]] = key.reshape(target_shape) checkpoint[path_map["value"]] = value.reshape(target_shape) for path in paths: new_path = path["new"] # These have already been assigned if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear if "proj_attn.weight" in new_path: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): keys = list(checkpoint.keys()) attn_keys = ["query.weight", "key.weight", "value.weight"] for key in keys: if ".".join(key.split(".")[-2:]) in attn_keys: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0, 0] elif "proj_attn.weight" in key: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0] def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): """ Creates a config for the diffusers based on the config of the LDM model. """ if controlnet: unet_params = original_config.model.params.control_stage_config.params else: unet_params = original_config.model.params.unet_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: head_dim = [5, 10, 20, 20] class_embed_type = None projection_class_embeddings_input_dim = None if "num_classes" in unet_params: if unet_params.num_classes == "sequential": class_embed_type = "projection" assert "adm_in_channels" in unet_params projection_class_embeddings_input_dim = unet_params.adm_in_channels else: raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") config = { "sample_size": image_size // vae_scale_factor, "in_channels": unet_params.in_channels, "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), "layers_per_block": unet_params.num_res_blocks, "cross_attention_dim": unet_params.context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, } if not controlnet: config["out_channels"] = unet_params.out_channels config["up_block_types"] = tuple(up_block_types) return config def create_vae_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. """ vae_params = original_config.model.params.first_stage_config.params.ddconfig _ = original_config.model.params.first_stage_config.params.embed_dim block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, "in_channels": vae_params.in_channels, "out_channels": vae_params.out_ch, "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), "latent_channels": vae_params.z_channels, "layers_per_block": vae_params.num_res_blocks, } return config def create_diffusers_schedular(original_config): schedular = DDIMScheduler( num_train_timesteps=original_config.model.params.timesteps, beta_start=original_config.model.params.linear_start, beta_end=original_config.model.params.linear_end, beta_schedule="scaled_linear", ) return schedular def create_ldm_bert_config(original_config): bert_params = original_config.model.parms.cond_stage_config.params config = LDMBertConfig( d_model=bert_params.n_embed, encoder_layers=bert_params.n_layer, encoder_ffn_dim=bert_params.n_embed * 4, ) return config def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): """ Takes a state dict and a config, and returns a converted checkpoint. """ # extract state_dict for UNet unet_state_dict = {} keys = list(checkpoint.keys()) if controlnet: unet_key = "control_model." else: unet_key = "model.diffusion_model." # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: print(f"Checkpoint {path} has both EMA and non-EMA weights.") print( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: print( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port ... elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) resnet_1_paths = renew_resnet_paths(resnet_1) assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) layer_in_block_id = i % (config["layers_per_block"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ] new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] # Clear attentions as they have been attributed above. if len(attentions) == 2: attentions = [] if len(attentions): paths = renew_attention_paths(attentions) meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] if controlnet: # conditioning embedding orig_index = 0 new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) orig_index += 2 diffusers_index = 0 while diffusers_index < 6: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) diffusers_index += 1 orig_index += 2 new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) # down blocks for i in range(num_input_blocks): new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") # mid block new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE vae_state_dict = {} vae_key = "first_stage_model." keys = list(checkpoint.keys()) for key in keys: if key.startswith(vae_key): vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) new_checkpoint = {} new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] # Retrieves the keys for the encoder down blocks only num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) down_blocks = { layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) up_blocks = { layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.weight" ) new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.bias" ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.weight" ] new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.bias" ] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) return new_checkpoint def convert_ldm_bert_checkpoint(checkpoint, config): def _copy_attn_layer(hf_attn_layer, pt_attn_layer): hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias def _copy_linear(hf_linear, pt_linear): hf_linear.weight = pt_linear.weight hf_linear.bias = pt_linear.bias def _copy_layer(hf_layer, pt_layer): # copy layer norms _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) # copy attn _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) # copy MLP pt_mlp = pt_layer[1][1] _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) _copy_linear(hf_layer.fc2, pt_mlp.net[2]) def _copy_layers(hf_layers, pt_layers): for i, hf_layer in enumerate(hf_layers): if i != 0: i += i pt_layer = pt_layers[i : i + 2] _copy_layer(hf_layer, pt_layer) hf_model = LDMBertModel(config).eval() # copy embeds hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight # copy layer norm _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) # copy hidden layers _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) return hf_model def convert_ldm_clip_checkpoint(checkpoint): text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") keys = list(checkpoint.keys()) text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] text_model.load_state_dict(text_model_dict) return text_model textenc_conversion_lst = [ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), ] textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} textenc_transformer_conversion_lst = [ # (stable-diffusion, HF Diffusers) ("resblocks.", "text_model.encoder.layers."), ("ln_1", "layer_norm1"), ("ln_2", "layer_norm2"), (".c_fc.", ".fc1."), (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) def convert_paint_by_example_checkpoint(checkpoint): config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") model = PaintByExampleImageEncoder(config) keys = list(checkpoint.keys()) text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] # load clip vision model.model.load_state_dict(text_model_dict) # load mapper keys_mapper = { k[len("cond_stage_model.mapper.res") :]: v for k, v in checkpoint.items() if k.startswith("cond_stage_model.mapper") } MAPPING = { "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], "attn.c_proj": ["attn1.to_out.0"], "ln_1": ["norm1"], "ln_2": ["norm3"], "mlp.c_fc": ["ff.net.0.proj"], "mlp.c_proj": ["ff.net.2"], } mapped_weights = {} for key, value in keys_mapper.items(): prefix = key[: len("blocks.i")] suffix = key.split(prefix)[-1].split(".")[-1] name = key.split(prefix)[-1].split(suffix)[0][1:-1] mapped_names = MAPPING[name] num_splits = len(mapped_names) for i, mapped_name in enumerate(mapped_names): new_name = ".".join([prefix, mapped_name, suffix]) shape = value.shape[0] // num_splits mapped_weights[new_name] = value[i * shape : (i + 1) * shape] model.mapper.load_state_dict(mapped_weights) # load final layer norm model.final_layer_norm.load_state_dict( { "bias": checkpoint["cond_stage_model.final_ln.bias"], "weight": checkpoint["cond_stage_model.final_ln.weight"], } ) # load final proj model.proj_out.load_state_dict( { "bias": checkpoint["proj_out.bias"], "weight": checkpoint["proj_out.weight"], } ) # load uncond vector model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) return model def convert_open_clip_checkpoint(checkpoint): text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") keys = list(checkpoint.keys()) text_model_dict = {} if "cond_stage_model.model.text_projection" in checkpoint: d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) else: d_model = 1024 text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") for key in keys: if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer continue if key in textenc_conversion_map: text_model_dict[textenc_conversion_map[key]] = checkpoint[key] if key.startswith("cond_stage_model.model.transformer."): new_key = key[len("cond_stage_model.model.transformer.") :] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] elif new_key.endswith(".in_proj_bias"): new_key = new_key[: -len(".in_proj_bias")] new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] else: new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) text_model_dict[new_key] = checkpoint[key] text_model.load_state_dict(text_model_dict) return text_model def stable_unclip_image_encoder(original_config): """ Returns the image processor and clip image encoder for the img2img unclip pipeline. We currently know of two types of stable unclip models which separately use the clip and the openclip image encoders. """ image_embedder_config = original_config.model.params.embedder_config sd_clip_image_embedder_class = image_embedder_config.target sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] if sd_clip_image_embedder_class == "ClipImageEmbedder": clip_model_name = image_embedder_config.params.model if clip_model_name == "ViT-L/14": feature_extractor = CLIPImageProcessor() image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") else: raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": feature_extractor = CLIPImageProcessor() image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") else: raise NotImplementedError( f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" ) return feature_extractor, image_encoder def stable_unclip_image_noising_components( original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None ): """ Returns the noising components for the img2img and txt2img unclip pipelines. Converts the stability noise augmentor into 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats 2. a `DDPMScheduler` for holding the noise schedule If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. """ noise_aug_config = original_config.model.params.noise_aug_config noise_aug_class = noise_aug_config.target noise_aug_class = noise_aug_class.split(".")[-1] if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": noise_aug_config = noise_aug_config.params embedding_dim = noise_aug_config.timestep_dim max_noise_level = noise_aug_config.noise_schedule_config.timesteps beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) if "clip_stats_path" in noise_aug_config: if clip_stats_path is None: raise ValueError("This stable unclip config requires a `clip_stats_path`") clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) clip_mean = clip_mean[None, :] clip_std = clip_std[None, :] clip_stats_state_dict = { "mean": clip_mean, "std": clip_std, } image_normalizer.load_state_dict(clip_stats_state_dict) else: raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") return image_normalizer, image_noising_scheduler def convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") controlnet_model = ControlNetModel(**ctrlnet_config) converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True ) controlnet_model.load_state_dict(converted_ctrl_checkpoint) return controlnet_model ================================================ FILE: animatediff/utils/convert_lora_safetensor_to_diffusers.py ================================================ # coding=utf-8 # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Changes were made to this source code by Yuwei Guo. """ Conversion script for the LoRA's safetensors checkpoints. """ import argparse import torch from safetensors.torch import load_file from diffusers import StableDiffusionPipeline def load_diffusers_lora(pipeline, state_dict, alpha=1.0): # directly update weight in diffusers model for key in state_dict: # only process lora down key if "up." in key: continue up_key = key.replace(".down.", ".up.") model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") model_key = model_key.replace("to_out.", "to_out.0.") layer_infos = model_key.split(".")[:-1] curr_layer = pipeline.unet while len(layer_infos) > 0: temp_name = layer_infos.pop(0) curr_layer = curr_layer.__getattr__(temp_name) weight_down = state_dict[key] weight_up = state_dict[up_key] curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) return pipeline def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): # load base model # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) # load LoRA weight from .safetensors # state_dict = load_file(checkpoint_path) visited = [] # directly update weight in diffusers model for key in state_dict: # it is suggested to print out the key, it usually will be something like below # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" # as we have set the alpha beforehand, so just skip if ".alpha" in key or key in visited: continue if "text" in key: layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") curr_layer = pipeline.text_encoder else: layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") curr_layer = pipeline.unet # find the target layer temp_name = layer_infos.pop(0) while len(layer_infos) > -1: try: curr_layer = curr_layer.__getattr__(temp_name) if len(layer_infos) > 0: temp_name = layer_infos.pop(0) elif len(layer_infos) == 0: break except Exception: if len(temp_name) > 0: temp_name += "_" + layer_infos.pop(0) else: temp_name = layer_infos.pop(0) pair_keys = [] if "lora_down" in key: pair_keys.append(key.replace("lora_down", "lora_up")) pair_keys.append(key) else: pair_keys.append(key) pair_keys.append(key.replace("lora_up", "lora_down")) # update weight if len(state_dict[pair_keys[0]].shape) == 4: weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) else: weight_up = state_dict[pair_keys[0]].to(torch.float32) weight_down = state_dict[pair_keys[1]].to(torch.float32) curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) # update visited list for item in pair_keys: visited.append(item) return pipeline if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." ) parser.add_argument( "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument( "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" ) parser.add_argument( "--lora_prefix_text_encoder", default="lora_te", type=str, help="The prefix of text encoder weight in safetensors", ) parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") parser.add_argument( "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." ) parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") args = parser.parse_args() base_model_path = args.base_model_path checkpoint_path = args.checkpoint_path dump_path = args.dump_path lora_prefix_unet = args.lora_prefix_unet lora_prefix_text_encoder = args.lora_prefix_text_encoder alpha = args.alpha pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) pipe = pipe.to(args.device) pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) ================================================ FILE: animatediff/utils/util.py ================================================ import os import imageio import numpy as np from typing import Union import torch import torchvision import torch.distributed as dist from huggingface_hub import snapshot_download from safetensors import safe_open from tqdm import tqdm from einops import rearrange from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora MOTION_MODULES = [ "mm_sd_v14.ckpt", "mm_sd_v15.ckpt", "mm_sd_v15_v2.ckpt", "v3_sd15_mm.ckpt", ] ADAPTERS = [ # "mm_sd_v14.ckpt", # "mm_sd_v15.ckpt", # "mm_sd_v15_v2.ckpt", # "mm_sdxl_v10_beta.ckpt", "v2_lora_PanLeft.ckpt", "v2_lora_PanRight.ckpt", "v2_lora_RollingAnticlockwise.ckpt", "v2_lora_RollingClockwise.ckpt", "v2_lora_TiltDown.ckpt", "v2_lora_TiltUp.ckpt", "v2_lora_ZoomIn.ckpt", "v2_lora_ZoomOut.ckpt", "v3_sd15_adapter.ckpt", # "v3_sd15_mm.ckpt", "v3_sd15_sparsectrl_rgb.ckpt", "v3_sd15_sparsectrl_scribble.ckpt", ] BACKUP_DREAMBOOTH_MODELS = [ "realisticVisionV60B1_v51VAE.safetensors", "majicmixRealistic_v4.safetensors", "leosamsFilmgirlUltra_velvia20Lora.safetensors", "toonyou_beta3.safetensors", "majicmixRealistic_v5Preview.safetensors", "rcnzCartoon3d_v10.safetensors", "lyriel_v16.safetensors", "leosamsHelloworldXL_filmGrain20.safetensors", "TUSUN.safetensors", ] def zero_rank_print(s): if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): videos = rearrange(videos, "b c t h w -> t b c h w") outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x + 1.0) / 2.0 # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) outputs.append(x) os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, fps=fps) def auto_download(local_path, is_dreambooth_lora=False): hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff" folder, filename = os.path.split(local_path) if not os.path.exists(local_path): print(f"local file {local_path} does not exist. trying to download from {hf_repo}") if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}" else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}" folder = "." if folder == "" else folder os.makedirs(folder, exist_ok=True) snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename]) def load_weights( animation_pipeline, # motion module motion_module_path = "", motion_module_lora_configs = [], # domain adapter adapter_lora_path = "", adapter_lora_scale = 1.0, # image layers dreambooth_model_path = "", lora_model_path = "", lora_alpha = 0.8, ): # motion module unet_state_dict = {} if motion_module_path != "": auto_download(motion_module_path, is_dreambooth_lora=False) print(f"load motion module from {motion_module_path}") motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict # filter parameters for name, param in motion_module_state_dict.items(): if not "motion_modules." in name: continue if "pos_encoder.pe" in name: continue unet_state_dict.update({name: param}) unet_state_dict.pop("animatediff_config", "") missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) assert len(unexpected) == 0 del unet_state_dict # base model if dreambooth_model_path != "": auto_download(dreambooth_model_path, is_dreambooth_lora=True) print(f"load dreambooth model from {dreambooth_model_path}") if dreambooth_model_path.endswith(".safetensors"): dreambooth_state_dict = {} with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: for key in f.keys(): dreambooth_state_dict[key] = f.get_tensor(key) elif dreambooth_model_path.endswith(".ckpt"): dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") # 1. vae converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) # 2. unet converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) # 3. text_model animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) del dreambooth_state_dict # lora layers if lora_model_path != "": auto_download(lora_model_path, is_dreambooth_lora=True) print(f"load lora model from {lora_model_path}") assert lora_model_path.endswith(".safetensors") lora_state_dict = {} with safe_open(lora_model_path, framework="pt", device="cpu") as f: for key in f.keys(): lora_state_dict[key] = f.get_tensor(key) animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) del lora_state_dict # domain adapter lora if adapter_lora_path != "": auto_download(adapter_lora_path, is_dreambooth_lora=False) print(f"load domain lora from {adapter_lora_path}") domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict domain_lora_state_dict.pop("animatediff_config", "") animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) # motion module lora for motion_module_lora_config in motion_module_lora_configs: path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] auto_download(path, is_dreambooth_lora=False) print(f"load motion LoRA from {path}") motion_lora_state_dict = torch.load(path, map_location="cpu") motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict motion_lora_state_dict.pop("animatediff_config", "") animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) return animation_pipeline ================================================ FILE: app.py ================================================ import os import json import torch import random import gradio as gr from glob import glob from omegaconf import OmegaConf from datetime import datetime from safetensors import safe_open from diffusers import AutoencoderKL from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler from diffusers.utils.import_utils import is_xformers_available from transformers import CLIPTextModel, CLIPTokenizer from animatediff.models.unet import UNet3DConditionModel from animatediff.pipelines.pipeline_animation import AnimationPipeline from animatediff.utils.util import save_videos_grid, load_weights, auto_download, MOTION_MODULES, BACKUP_DREAMBOOTH_MODELS from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora import pdb sample_idx = 0 scheduler_dict = { "DDIM": DDIMScheduler, "Euler": EulerDiscreteScheduler, "PNDM": PNDMScheduler, } css = """ .toolbutton { margin-buttom: 0em 0em 0em 0em; max-width: 2.5em; min-width: 2.5em !important; height: 2.5em; } """ PRETRAINED_SD = "runwayml/stable-diffusion-v1-5" default_motion_module = "v3_sd15_mm.ckpt" default_inference_config = "configs/inference/inference-v3.yaml" default_dreambooth_model = "realisticVisionV60B1_v51VAE.safetensors" default_prompt = "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" default_n_prompt = "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" default_seed = 8893659352891878017 device = "cuda" if torch.cuda.is_available() else "cpu" class AnimateController: def __init__(self): # config dirs self.basedir = os.getcwd() self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion") self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA") self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) self.savedir_sample = os.path.join(self.savedir, "sample") os.makedirs(self.savedir, exist_ok=True) self.stable_diffusion_list = [PRETRAINED_SD] self.motion_module_list = MOTION_MODULES self.personalized_model_list = BACKUP_DREAMBOOTH_MODELS # config models self.pipeline = None # self.lora_model_state_dict = {} self.refresh_stable_diffusion() self.refresh_personalized_model() # default setting self.update_pipeline( stable_diffusion_dropdown=PRETRAINED_SD, motion_module_dropdown=default_motion_module, base_model_dropdown=default_dreambooth_model, sampler_dropdown="DDIM", ) def refresh_stable_diffusion(self): self.stable_diffusion_list = [PRETRAINED_SD] + glob(os.path.join(self.stable_diffusion_dir, "*/")) def refresh_personalized_model(self): personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors")) self.personalized_model_list = BACKUP_DREAMBOOTH_MODELS + [os.path.basename(p) for p in personalized_model_list if os.path.basename(p) not in BACKUP_DREAMBOOTH_MODELS] # for dropdown update def update_pipeline( self, stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown="", lora_model_dropdown="none", lora_alpha_dropdown="0.6", sampler_dropdown="DDIM", ): if "v2" in motion_module_dropdown: inference_config = "configs/inference/inference-v2.yaml" elif "v3" in motion_module_dropdown: inference_config = "configs/inference/inference-v3.yaml" else: inference_config = "configs/inference/inference-v1.yaml" unet = UNet3DConditionModel.from_pretrained_2d( stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.load(inference_config).unet_additional_kwargs ) if is_xformers_available() and torch.cuda.is_available(): unet.enable_xformers_memory_efficient_attention() noise_scheduler_cls = scheduler_dict[sampler_dropdown] noise_scheduler_kwargs = OmegaConf.load(inference_config).noise_scheduler_kwargs if noise_scheduler_cls == EulerDiscreteScheduler: noise_scheduler_kwargs.pop("steps_offset") noise_scheduler_kwargs.pop("clip_sample") elif noise_scheduler_cls == PNDMScheduler: noise_scheduler_kwargs.pop("clip_sample") pipeline = AnimationPipeline( unet=unet, vae=AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae"), text_encoder=CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder"), tokenizer=CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer"), scheduler=noise_scheduler_cls(**noise_scheduler_kwargs), ) pipeline = load_weights( pipeline, motion_module_path=os.path.join(self.motion_module_dir, motion_module_dropdown), dreambooth_model_path=os.path.join(self.personalized_model_dir, base_model_dropdown) if base_model_dropdown != "" else "", lora_model_path=os.path.join(self.personalized_model_dir, lora_model_dropdown) if lora_model_dropdown != "none" else "", lora_alpha=float(lora_alpha_dropdown), ) pipeline.to(device) self.pipeline = pipeline print("done.") return gr.Dropdown.update() def update_pipeline_alpha( self, stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown="", lora_model_dropdown="none", lora_alpha_dropdown="0.6", sampler_dropdown="DDIM", ): if lora_model_dropdown == "none": return gr.Slider.update() self.update_pipeline( stable_diffusion_dropdown=stable_diffusion_dropdown, motion_module_dropdown=motion_module_dropdown, base_model_dropdown=base_model_dropdown, lora_model_dropdown=lora_model_dropdown, lora_alpha_dropdown=lora_alpha_dropdown, sampler_dropdown=sampler_dropdown, ) return gr.Slider.update() @torch.no_grad() def animate( self, prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox, ): if int(seed_textbox) != -1: torch.manual_seed(int(seed_textbox)) else: torch.seed() seed = torch.initial_seed() sample = self.pipeline( prompt_textbox, negative_prompt = negative_prompt_textbox, num_inference_steps = sample_step_slider, guidance_scale = cfg_scale_slider, width = width_slider, height = height_slider, video_length = length_slider, ).videos save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4") save_videos_grid(sample, save_sample_path) sample_config = { "prompt": prompt_textbox, "n_prompt": negative_prompt_textbox, "sampler": sampler_dropdown, "num_inference_steps": sample_step_slider, "guidance_scale": cfg_scale_slider, "width": width_slider, "height": height_slider, "video_length": length_slider, "seed": seed } json_str = json.dumps(sample_config, indent=4) with open(os.path.join(self.savedir, "logs.json"), "a") as f: f.write(json_str) f.write("\n\n") return gr.Video.update(value=save_sample_path) controller = AnimateController() def ui(): with gr.Blocks(css=css) as demo: gr.Markdown( """ # AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning Yuwei Guo, Ceyuan Yang✝, Anyi Rao, Zhengyang Liang, Yaohui Wang, Yu Qiao, Maneesh Agrawala, Dahua Lin, Bo Dai (✝Corresponding Author)
[Paper](https://arxiv.org/abs/2307.04725) | [Webpage](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/) """ ) with gr.Column(variant="panel"): gr.Markdown( """ ### 1. Model Checkpoints """ ) with gr.Row(): stable_diffusion_dropdown = gr.Dropdown( label="Pretrained Model Path", choices=controller.stable_diffusion_list, value=PRETRAINED_SD, interactive=True, ) with gr.Row(): motion_module_dropdown = gr.Dropdown( label="Select motion module", choices=controller.motion_module_list, value=default_motion_module, interactive=True, ) base_model_dropdown = gr.Dropdown( label="Select base Dreambooth model (required)", choices=controller.personalized_model_list, value=default_dreambooth_model, interactive=True, ) lora_model_dropdown = gr.Dropdown( label="Select LoRA model (optional)", choices=["none"] + controller.personalized_model_list, value="none", interactive=True, ) lora_alpha_dropdown = gr.Dropdown( label="LoRA alpha", choices=["0.", "0.2", "0.4", "0.6", "0.8", "1.0"], value="0.6", interactive=True, ) personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") def update_personalized_model(): controller.refresh_stable_diffusion() controller.refresh_personalized_model() return [ gr.Dropdown.update(choices=controller.stable_diffusion_list), gr.Dropdown.update(choices=controller.personalized_model_list), gr.Dropdown.update(choices=["none"] + controller.personalized_model_list) ] personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[stable_diffusion_dropdown, base_model_dropdown, lora_model_dropdown]) with gr.Column(variant="panel"): gr.Markdown( """ ### 2. Configs for AnimateDiff. """ ) prompt_textbox = gr.Textbox(label="Prompt", lines=2, value=default_prompt) negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value=default_n_prompt) with gr.Row().style(equal_height=False): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1) width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64) height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64) length_slider = gr.Slider(label="Animation length (default: 16)", value=16, minimum=8, maximum=24, step=1) cfg_scale_slider = gr.Slider(label="CFG Scale", value=8.0, minimum=0, maximum=20) with gr.Row(): seed_textbox = gr.Textbox(label="Seed (-1 for random seed)", value=default_seed) seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) generate_button = gr.Button(value="Generate", variant='primary') result_video = gr.Video(label="Generated Animation", interactive=False) # update method stable_diffusion_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[stable_diffusion_dropdown]) motion_module_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[motion_module_dropdown]) base_model_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[base_model_dropdown]) lora_model_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[lora_model_dropdown]) lora_alpha_dropdown.change(fn=controller.update_pipeline_alpha, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[lora_alpha_dropdown]) generate_button.click( fn=controller.animate, inputs=[ prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox, ], outputs=[result_video] ) return demo if __name__ == "__main__": demo = ui() demo.launch(share=True) ================================================ FILE: configs/inference/inference-v1.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: false use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_attention_dim_div: 1 zero_initialize: true noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: false ================================================ FILE: configs/inference/inference-v2.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: true motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_attention_dim_div: 1 zero_initialize: true noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: false ================================================ FILE: configs/inference/inference-v3.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_attention_dim_div: 1 zero_initialize: true noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: false ================================================ FILE: configs/inference/sparsectrl/image_condition.yaml ================================================ controlnet_additional_kwargs: set_noisy_sample_input_to_zero: true use_simplified_condition_embedding: false conditioning_channels: 3 use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 ================================================ FILE: configs/inference/sparsectrl/latent_condition.yaml ================================================ controlnet_additional_kwargs: set_noisy_sample_input_to_zero: true use_simplified_condition_embedding: true conditioning_channels: 4 use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 ================================================ FILE: configs/prompts/1_animate/1_1_animate_RealisticVision.yaml ================================================ # motion module v3 - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [8893659352891878017, 9317678091797131699, 43242532350557906, 4162228652802886667] steps: 25 guidance_scale: 8 prompt: - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot" - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" n_prompt: - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # motion module v2 - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [8964153601421814582, 10589116295929063558, 13214918285578813247, 3460258020075528001] steps: 25 guidance_scale: 8 prompt: - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot" - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" n_prompt: - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" ================================================ FILE: configs/prompts/1_animate/1_2_animate_FilmVelvia.yaml ================================================ # motion module v1_14 - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors" lora_model_path: "models/DreamBooth_LoRA/leosamsFilmgirlUltra_velvia20Lora.safetensors" lora_alpha: 0.6 inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [5726977427157971918, 18368660165286593270, 9350384325017735240, 2097615141377450078] steps: 25 guidance_scale: 8 prompt: - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name" - "dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir" - "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark" - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, " n_prompt: - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" # motion module v1_15 - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors" lora_model_path: "models/DreamBooth_LoRA/leosamsFilmgirlUltra_velvia20Lora.safetensors" lora_alpha: 0.6 inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [2802659149552239028, 12507673598434739425, 1350017671114249824, 2813556755112853775] steps: 25 guidance_scale: 8 prompt: - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name" - ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir" - "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark" - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, " n_prompt: - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" ================================================ FILE: configs/prompts/1_animate/1_3_animate_ToonYou.yaml ================================================ # motion module v3 - dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [12192490710448890259, 12238800062118732365, 13226337751639812613, 16431231374396590344] steps: 25 guidance_scale: 7.5 prompt: - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress" - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes," - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern" - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle," n_prompt: - "worst quality, low quality, letterboxed" # motion module v2 - dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [2362336635702964940, 8149279371559927917, 1487371078234460867, 17554906328875363976] steps: 25 guidance_scale: 7.5 prompt: - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress" - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes," - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern" - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle," n_prompt: - "worst quality, low quality, letterboxed" ================================================ FILE: configs/prompts/1_animate/1_4_animate_MajicMix.yaml ================================================ # motion module v1_14 - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [11413213594134208212, 11357183503136546592, 7315638361411279346, 10191753182015596097] steps: 25 guidance_scale: 8 prompt: - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic" - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting" - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below" - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic" n_prompt: - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles" - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people" # motion module v1_15 - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [3364626746360550707, 10635741750919791646, 3130334860012077860, 1530101570151479035] steps: 25 guidance_scale: 8 prompt: - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic" - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting" - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below" - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic" n_prompt: - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles" - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people" ================================================ FILE: configs/prompts/1_animate/1_5_animate_RcnzCartoon.yaml ================================================ # motion module v1_14 - dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [8085079222822100088, 15493278891844617620, 17384760730172253253, 3896292336733512420] steps: 25 guidance_scale: 8 prompt: - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded" - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face" - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes" - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering" n_prompt: - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation" - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular" - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face," - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand" # motion module v1_15 - dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [10087939632512181573, 6440765888009826001, 4292543217695451092, 14003068315619866795] steps: 25 guidance_scale: 8 prompt: - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded" - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face" - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes" - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering" n_prompt: - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation" - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular" - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face," - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand" ================================================ FILE: configs/prompts/1_animate/1_6_animate_Lyriel.yaml ================================================ # motion module v1_14 - dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551] steps: 25 guidance_scale: 8 prompt: - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange" - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal" - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray" - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown." n_prompt: - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration" - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular" - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome" - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render" # motion module v1_15 - dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors" lora_model_path: "" inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [9217823730598265840, 10815047877257294769, 15033600051075248739, 3730216622332453211] steps: 25 guidance_scale: 8 prompt: - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange" - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal" - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray" - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown." n_prompt: - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration" - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular" - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome" - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render" ================================================ FILE: configs/prompts/1_animate/1_7_animate_Tusun.yaml ================================================ # motion module v1_14 - dreambooth_path: "models/DreamBooth_LoRA/leosamsHelloworldXL_filmGrain20.safetensors" lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors" lora_alpha: 0.6 inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" seed: [7107114461349773341, 17169636352587613974, 9844335976427375435, 6372518434592560610] steps: 25 guidance_scale: 8 prompt: - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body" n_prompt: - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative" # motion module v1_15 - dreambooth_path: "models/DreamBooth_LoRA/leosamsHelloworldXL_filmGrain20.safetensors" lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors" lora_alpha: 0.6 inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" seed: [8605999221232672724, 110148213803975296, 9191327304973552413, 174075196208604916] steps: 25 guidance_scale: 8 prompt: - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body" n_prompt: - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative" ================================================ FILE: configs/prompts/2_motionlora/2_motionlora_RealisticVision.yaml ================================================ # ZoomIn - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_ZoomIn.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # ZoomOut - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_ZoomOut.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # PanLeft - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_PanLeft.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # PanRight - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_PanRight.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # TiltUp - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_TiltUp.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # TiltDown - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_TiltDown.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # RollingAnticlockwise - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" # RollingClockwise - inference_config: "configs/inference/inference-v2.yaml" motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" motion_module_lora_configs: - path: "models/MotionLoRA/v2_lora_RollingClockwise.ckpt" alpha: 1.0 dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" lora_model_path: "" seed: 43242532350557906 steps: 25 guidance_scale: 7.5 prompt: - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" n_prompt: - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" ================================================ FILE: configs/prompts/3_sparsectrl/3_1_sparsectrl_i2v.yaml ================================================ # 1-animation - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" H: 256 W: 384 seed: [123,234] steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0] controlnet_images: - "__assets__/demos/image/painting.png" prompt: - an oil painting of a sailboat in the ocean wave - an oil painting of a sailboat in the ocean wave n_prompt: - "worst quality, low quality, letterboxed" # 2-interpolation - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" H: 256 W: 384 seed: [123,234] steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0,-1] controlnet_images: - "__assets__/demos/image/interpolation_1.png" - "__assets__/demos/image/interpolation_2.png" prompt: - "aerial view, beautiful forest, autumn, 4k, high quality" - "aerial view, beautiful forest, autumn, 4k, high quality" n_prompt: - "worst quality, low quality, letterboxed" # 3-interpolation - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" H: 256 W: 384 seed: [123,234] steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0,5,10,15] controlnet_images: - "__assets__/demos/image/low_fps_1.png" - "__assets__/demos/image/low_fps_2.png" - "__assets__/demos/image/low_fps_3.png" - "__assets__/demos/image/low_fps_4.png" prompt: - "two people holding hands in a field with wind turbines in the background" - "two people holding hands in a field with wind turbines in the background" n_prompt: - "worst quality, low quality, letterboxed" # 3-prediction - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" H: 256 W: 384 seed: [123,234] steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0,1,2,3] controlnet_images: - "__assets__/demos/image/prediction_1.png" - "__assets__/demos/image/prediction_2.png" - "__assets__/demos/image/prediction_3.png" - "__assets__/demos/image/prediction_4.png" prompt: - "an astronaut is flying in the space, 4k, high resolution" - "an astronaut is flying in the space, 4k, high resolution" n_prompt: - "worst quality, low quality, letterboxed" ================================================ FILE: configs/prompts/3_sparsectrl/3_2_sparsectrl_rgb_RealisticVision.yaml ================================================ # animation-1 - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" seed: -1 steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0] controlnet_images: - "__assets__/demos/image/RealisticVision_firework.png" prompt: - "closeup face photo of man in black clothes, night city street, bokeh, fireworks in background" - "closeup face photo of man in black clothes, night city street, bokeh, fireworks in background" n_prompt: - "worst quality, low quality, letterboxed" # animation-2 - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" seed: -1 steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0] controlnet_images: - "__assets__/demos/image/RealisticVision_sunset.png" prompt: - "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, orange sky, warm lighting, fishing boats, ocean waves, seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, golden hour, coastal landscape, seaside scenery" - "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, orange sky, warm lighting, fishing boats, ocean waves, seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, golden hour, coastal landscape, seaside scenery" n_prompt: - "worst quality, low quality, letterboxed" ================================================ FILE: configs/prompts/3_sparsectrl/3_3_sparsectrl_sketch_RealisticVision.yaml ================================================ # 1-sketch-to-video - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/image_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt" seed: -1 steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0] controlnet_images: - "__assets__/demos/scribble/scribble_1.png" prompt: - "a back view of a boy, standing on the ground, looking at the sky, sunlight, masterpieces" - "a back view of a boy, standing on the ground, looking at the sky, clouds, sunset, orange sky, beautiful sunlight, masterpieces" n_prompt: - "worst quality, low quality, letterboxed" # 2-storyboarding - adapter_lora_scale: 1.0 adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" inference_config: "configs/inference/inference-v3.yaml" motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" controlnet_config: "configs/inference/sparsectrl/image_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt" seed: -1 steps: 25 guidance_scale: 8.5 controlnet_image_indexs: [0,8,15] controlnet_images: - "__assets__/demos/scribble/scribble_2_1.png" - "__assets__/demos/scribble/scribble_2_2.png" - "__assets__/demos/scribble/scribble_2_3.png" prompt: - "an aerial view of a modern city, sunlight, day time, masterpiece, high quality" - "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality" n_prompt: - "worst quality, low quality, letterboxed" ================================================ FILE: configs/training/v1/image_finetune.yaml ================================================ image_finetune: true output_dir: "outputs" pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" noise_scheduler_kwargs: num_train_timesteps: 1000 beta_start: 0.00085 beta_end: 0.012 beta_schedule: "scaled_linear" steps_offset: 1 clip_sample: false train_data: csv_path: "path_to_csv_file" video_folder: "path_to_video_foler" sample_size: 256 validation_data: prompts: - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." - "A drone view of celebration with Christma tree and fireworks, starry sky - background." - "Robot dancing in times square." - "Pacific coast, carmel by the sea ocean and waves." num_inference_steps: 25 guidance_scale: 8. trainable_modules: - "." unet_checkpoint_path: "" learning_rate: 1.e-5 train_batch_size: 50 max_train_epoch: -1 max_train_steps: 100 checkpointing_epochs: -1 checkpointing_steps: 60 validation_steps: 5000 validation_steps_tuple: [2, 50] global_seed: 42 mixed_precision_training: true enable_xformers_memory_efficient_attention: True is_debug: False ================================================ FILE: configs/training/v1/training.yaml ================================================ image_finetune: false output_dir: "outputs" pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" unet_additional_kwargs: use_motion_module : true motion_module_resolutions : [ 1,2,4,8 ] unet_use_cross_frame_attention : false unet_use_temporal_attention : false motion_module_type: Vanilla motion_module_kwargs: num_attention_heads : 8 num_transformer_block : 1 attention_block_types : [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding : true temporal_position_encoding_max_len : 24 temporal_attention_dim_div : 1 zero_initialize : true noise_scheduler_kwargs: num_train_timesteps: 1000 beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: false train_data: csv_path: "path_to_csv_file" video_folder: "path_to_video_foler" sample_size: 256 sample_stride: 4 sample_n_frames: 16 validation_data: prompts: - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." - "A drone view of celebration with Christma tree and fireworks, starry sky - background." - "Robot dancing in times square." - "Pacific coast, carmel by the sea ocean and waves." num_inference_steps: 25 guidance_scale: 8. trainable_modules: - "motion_modules." unet_checkpoint_path: "" learning_rate: 1.e-4 train_batch_size: 4 max_train_epoch: -1 max_train_steps: 100 checkpointing_epochs: -1 checkpointing_steps: 60 validation_steps: 5000 validation_steps_tuple: [2, 50] global_seed: 42 mixed_precision_training: true enable_xformers_memory_efficient_attention: True is_debug: False ================================================ FILE: requirements.txt ================================================ torch==2.3.1 torchvision==0.18.1 diffusers==0.11.1 transformers==4.25.1 xformers==0.0.27 imageio==2.27.0 imageio-ffmpeg==0.4.9 decord==0.6.0 omegaconf==2.3.0 gradio==3.36.1 safetensors einops wandb ================================================ FILE: scripts/animate.py ================================================ import argparse import datetime import inspect import os from omegaconf import OmegaConf import torch import torchvision.transforms as transforms import diffusers from diffusers import AutoencoderKL, DDIMScheduler from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from animatediff.models.unet import UNet3DConditionModel from animatediff.models.sparse_controlnet import SparseControlNetModel from animatediff.pipelines.pipeline_animation import AnimationPipeline from animatediff.utils.util import save_videos_grid from animatediff.utils.util import load_weights, auto_download from diffusers.utils.import_utils import is_xformers_available from einops import rearrange, repeat import csv, pdb, glob, math from pathlib import Path from PIL import Image import numpy as np @torch.no_grad() def main(args): *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") savedir = f"samples/{Path(args.config).stem}-{time_str}" os.makedirs(savedir) config = OmegaConf.load(args.config) samples = [] # create validation pipeline tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").cuda() vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").cuda() sample_idx = 0 for model_idx, model_config in enumerate(config): model_config.W = model_config.get("W", args.W) model_config.H = model_config.get("H", args.H) model_config.L = model_config.get("L", args.L) inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config)) unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda() # load controlnet model controlnet = controlnet_images = None if model_config.get("controlnet_path", "") != "": assert model_config.get("controlnet_images", "") != "" assert model_config.get("controlnet_config", "") != "" unet.config.num_attention_heads = 8 unet.config.projection_class_embeddings_input_dim = None controlnet_config = OmegaConf.load(model_config.controlnet_config) controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})) auto_download(model_config.controlnet_path, is_dreambooth_lora=False) print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...") controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu") controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name} controlnet_state_dict.pop("animatediff_config", "") controlnet.load_state_dict(controlnet_state_dict) controlnet.cuda() image_paths = model_config.controlnet_images if isinstance(image_paths, str): image_paths = [image_paths] print(f"controlnet image paths:") for path in image_paths: print(path) assert len(image_paths) <= model_config.L image_transforms = transforms.Compose([ transforms.RandomResizedCrop( (model_config.H, model_config.W), (1.0, 1.0), ratio=(model_config.W/model_config.H, model_config.W/model_config.H) ), transforms.ToTensor(), ]) if model_config.get("normalize_condition_images", False): def image_norm(image): image = image.mean(dim=0, keepdim=True).repeat(3,1,1) image -= image.min() image /= image.max() return image else: image_norm = lambda x: x controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths] os.makedirs(os.path.join(savedir, "control_images"), exist_ok=True) for i, image in enumerate(controlnet_images): Image.fromarray((255. * (image.numpy().transpose(1,2,0))).astype(np.uint8)).save(f"{savedir}/control_images/{i}.png") controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda() controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") if controlnet.use_simplified_condition_embedding: num_controlnet_images = controlnet_images.shape[2] controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") controlnet_images = vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215 controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) # set xformers if is_xformers_available() and (not args.without_xformers): unet.enable_xformers_memory_efficient_attention() if controlnet is not None: controlnet.enable_xformers_memory_efficient_attention() pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), ).to("cuda") pipeline = load_weights( pipeline, # motion module motion_module_path = model_config.get("motion_module", ""), motion_module_lora_configs = model_config.get("motion_module_lora_configs", []), # domain adapter adapter_lora_path = model_config.get("adapter_lora_path", ""), adapter_lora_scale = model_config.get("adapter_lora_scale", 1.0), # image layers dreambooth_model_path = model_config.get("dreambooth_path", ""), lora_model_path = model_config.get("lora_model_path", ""), lora_alpha = model_config.get("lora_alpha", 0.8), ).to("cuda") prompts = model_config.prompt n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt random_seeds = model_config.get("seed", [-1]) random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds config[model_idx].random_seed = [] for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)): # manually set random seed for reproduction if random_seed != -1: torch.manual_seed(random_seed) else: torch.seed() config[model_idx].random_seed.append(torch.initial_seed()) print(f"current seed: {torch.initial_seed()}") print(f"sampling {prompt} ...") sample = pipeline( prompt, negative_prompt = n_prompt, num_inference_steps = model_config.steps, guidance_scale = model_config.guidance_scale, width = model_config.W, height = model_config.H, video_length = model_config.L, controlnet_images = controlnet_images, controlnet_image_index = model_config.get("controlnet_image_indexs", [0]), ).videos samples.append(sample) prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif") print(f"save to {savedir}/sample/{prompt}.gif") sample_idx += 1 samples = torch.concat(samples) save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4) OmegaConf.save(config, f"{savedir}/config.yaml") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--pretrained-model-path", type=str, default="runwayml/stable-diffusion-v1-5") parser.add_argument("--inference-config", type=str, default="configs/inference/inference-v1.yaml") parser.add_argument("--config", type=str, required=True) parser.add_argument("--L", type=int, default=16 ) parser.add_argument("--W", type=int, default=512) parser.add_argument("--H", type=int, default=512) parser.add_argument("--without-xformers", action="store_true") args = parser.parse_args() main(args) ================================================ FILE: train.py ================================================ import os import math import wandb import random import logging import inspect import argparse import datetime import subprocess from pathlib import Path from tqdm.auto import tqdm from einops import rearrange from omegaconf import OmegaConf from safetensors import safe_open from typing import Dict, Optional, Tuple import torch import torchvision import torch.nn.functional as F import torch.distributed as dist from torch.optim.swa_utils import AveragedModel from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel import DistributedDataParallel as DDP import diffusers from diffusers import AutoencoderKL, DDIMScheduler from diffusers.models import UNet2DConditionModel from diffusers.pipelines import StableDiffusionPipeline from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available import transformers from transformers import CLIPTextModel, CLIPTokenizer from animatediff.data.dataset import WebVid10M from animatediff.models.unet import UNet3DConditionModel from animatediff.pipelines.pipeline_animation import AnimationPipeline from animatediff.utils.util import save_videos_grid, zero_rank_print def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs): """Initializes distributed environment.""" if launcher == 'pytorch': rank = int(os.environ['RANK']) num_gpus = torch.cuda.device_count() local_rank = rank % num_gpus torch.cuda.set_device(local_rank) dist.init_process_group(backend=backend, **kwargs) elif launcher == 'slurm': proc_id = int(os.environ['SLURM_PROCID']) ntasks = int(os.environ['SLURM_NTASKS']) node_list = os.environ['SLURM_NODELIST'] num_gpus = torch.cuda.device_count() local_rank = proc_id % num_gpus torch.cuda.set_device(local_rank) addr = subprocess.getoutput( f'scontrol show hostname {node_list} | head -n1') os.environ['MASTER_ADDR'] = addr os.environ['WORLD_SIZE'] = str(ntasks) os.environ['RANK'] = str(proc_id) port = os.environ.get('PORT', port) os.environ['MASTER_PORT'] = str(port) dist.init_process_group(backend=backend) zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}") else: raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!') return local_rank def main( image_finetune: bool, name: str, use_wandb: bool, launcher: str, output_dir: str, pretrained_model_path: str, train_data: Dict, validation_data: Dict, cfg_random_null_text: bool = True, cfg_random_null_text_ratio: float = 0.1, unet_checkpoint_path: str = "", unet_additional_kwargs: Dict = {}, ema_decay: float = 0.9999, noise_scheduler_kwargs = None, max_train_epoch: int = -1, max_train_steps: int = 100, validation_steps: int = 100, validation_steps_tuple: Tuple = (-1,), learning_rate: float = 3e-5, scale_lr: bool = False, lr_warmup_steps: int = 0, lr_scheduler: str = "constant", trainable_modules: Tuple[str] = (None, ), num_workers: int = 32, train_batch_size: int = 1, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_weight_decay: float = 1e-2, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, gradient_accumulation_steps: int = 1, gradient_checkpointing: bool = False, checkpointing_epochs: int = 5, checkpointing_steps: int = -1, mixed_precision_training: bool = True, enable_xformers_memory_efficient_attention: bool = True, global_seed: int = 42, is_debug: bool = False, ): check_min_version("0.10.0.dev0") # Initialize distributed training local_rank = init_dist(launcher=launcher) global_rank = dist.get_rank() num_processes = dist.get_world_size() is_main_process = global_rank == 0 seed = global_seed + global_rank torch.manual_seed(seed) # Logging folder folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S") output_dir = os.path.join(output_dir, folder_name) if is_debug and os.path.exists(output_dir): os.system(f"rm -rf {output_dir}") *_, config = inspect.getargvalues(inspect.currentframe()) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) if is_main_process and (not is_debug) and use_wandb: run = wandb.init(project="animatediff", name=folder_name, config=config) # Handle the output folder creation if is_main_process: os.makedirs(output_dir, exist_ok=True) os.makedirs(f"{output_dir}/samples", exist_ok=True) os.makedirs(f"{output_dir}/sanity_check", exist_ok=True) os.makedirs(f"{output_dir}/checkpoints", exist_ok=True) OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) # Load scheduler, tokenizer and models. noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") if not image_finetune: unet = UNet3DConditionModel.from_pretrained_2d( pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) ) else: unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") # Load pretrained unet weights if unet_checkpoint_path != "": zero_rank_print(f"from checkpoint: {unet_checkpoint_path}") unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu") if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}") state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path m, u = unet.load_state_dict(state_dict, strict=False) zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") assert len(u) == 0 # Freeze vae and text_encoder vae.requires_grad_(False) text_encoder.requires_grad_(False) # Set unet trainable parameters unet.requires_grad_(False) for name, param in unet.named_parameters(): for trainable_module_name in trainable_modules: if trainable_module_name in name: param.requires_grad = True break trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters())) optimizer = torch.optim.AdamW( trainable_params, lr=learning_rate, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay, eps=adam_epsilon, ) if is_main_process: zero_rank_print(f"trainable params number: {len(trainable_params)}") zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M") # Enable xformers if enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") # Enable gradient checkpointing if gradient_checkpointing: unet.enable_gradient_checkpointing() # Move models to GPU vae.to(local_rank) text_encoder.to(local_rank) # Get the training dataset train_dataset = WebVid10M(**train_data, is_image=image_finetune) distributed_sampler = DistributedSampler( train_dataset, num_replicas=num_processes, rank=global_rank, shuffle=True, seed=global_seed, ) # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=train_batch_size, shuffle=False, sampler=distributed_sampler, num_workers=num_workers, pin_memory=True, drop_last=True, ) # Get the training iteration if max_train_steps == -1: assert max_train_epoch != -1 max_train_steps = max_train_epoch * len(train_dataloader) if checkpointing_steps == -1: assert checkpointing_epochs != -1 checkpointing_steps = checkpointing_epochs * len(train_dataloader) if scale_lr: learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes) # Scheduler lr_scheduler = get_scheduler( lr_scheduler, optimizer=optimizer, num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, num_training_steps=max_train_steps * gradient_accumulation_steps, ) # Validation pipeline if not image_finetune: validation_pipeline = AnimationPipeline( unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, ).to("cuda") else: validation_pipeline = StableDiffusionPipeline.from_pretrained( pretrained_model_path, unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None, ) validation_pipeline.enable_vae_slicing() # DDP warpper unet.to(local_rank) unet = DDP(unet, device_ids=[local_rank], output_device=local_rank) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) # Afterwards we recalculate our number of training epochs num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) # Train! total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps if is_main_process: logging.info("***** Running training *****") logging.info(f" Num examples = {len(train_dataset)}") logging.info(f" Num Epochs = {num_train_epochs}") logging.info(f" Instantaneous batch size per device = {train_batch_size}") logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") logging.info(f" Total optimization steps = {max_train_steps}") global_step = 0 first_epoch = 0 # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process) progress_bar.set_description("Steps") # Support mixed-precision training scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None for epoch in range(first_epoch, num_train_epochs): train_dataloader.sampler.set_epoch(epoch) unet.train() for step, batch in enumerate(train_dataloader): if cfg_random_null_text: batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']] # Data batch sanity check if epoch == first_epoch and step == 0: pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] if not image_finetune: pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): pixel_value = pixel_value[None, ...] save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True) else: for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): pixel_value = pixel_value / 2. + 0.5 torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png") ### >>>> Training >>>> ### # Convert videos to latent space pixel_values = batch["pixel_values"].to(local_rank) video_length = pixel_values.shape[1] with torch.no_grad(): if not image_finetune: pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") latents = vae.encode(pixel_values).latent_dist latents = latents.sample() latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) else: latents = vae.encode(pixel_values).latent_dist latents = latents.sample() latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each video timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning with torch.no_grad(): prompt_ids = tokenizer( batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids.to(latents.device) encoder_hidden_states = text_encoder(prompt_ids)[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": raise NotImplementedError else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss # Mixed-precision training with torch.cuda.amp.autocast(enabled=mixed_precision_training): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") optimizer.zero_grad() # Backpropagate if mixed_precision_training: scaler.scale(loss).backward() """ >>> gradient clipping >>> """ scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm) """ <<< gradient clipping <<< """ scaler.step(optimizer) scaler.update() else: loss.backward() """ >>> gradient clipping >>> """ torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm) """ <<< gradient clipping <<< """ optimizer.step() lr_scheduler.step() progress_bar.update(1) global_step += 1 ### <<<< Training <<<< ### # Wandb logging if is_main_process and (not is_debug) and use_wandb: wandb.log({"train_loss": loss.item()}, step=global_step) # Save checkpoint if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1): save_path = os.path.join(output_dir, f"checkpoints") state_dict = { "epoch": epoch, "global_step": global_step, "state_dict": unet.state_dict(), } if step == len(train_dataloader) - 1: torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt")) else: torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt")) logging.info(f"Saved state to {save_path} (global_step: {global_step})") # Periodically validation if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple): samples = [] generator = torch.Generator(device=latents.device) generator.manual_seed(global_seed) height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts for idx, prompt in enumerate(prompts): if not image_finetune: sample = validation_pipeline( prompt, generator = generator, video_length = train_data.sample_n_frames, height = height, width = width, **validation_data, ).videos save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") samples.append(sample) else: sample = validation_pipeline( prompt, generator = generator, height = height, width = width, num_inference_steps = validation_data.get("num_inference_steps", 25), guidance_scale = validation_data.get("guidance_scale", 8.), ).images[0] sample = torchvision.transforms.functional.to_tensor(sample) samples.append(sample) if not image_finetune: samples = torch.concat(samples) save_path = f"{output_dir}/samples/sample-{global_step}.gif" save_videos_grid(samples, save_path) else: samples = torch.stack(samples) save_path = f"{output_dir}/samples/sample-{global_step}.png" torchvision.utils.save_image(samples, save_path, nrow=4) logging.info(f"Saved samples to {save_path}") logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= max_train_steps: break dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch") parser.add_argument("--wandb", action="store_true") args = parser.parse_args() name = Path(args.config).stem config = OmegaConf.load(args.config) main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)