Repository: MC-E/ReVideo Branch: main Commit: 174645250d9d Files: 80 Total size: 445.7 KB Directory structure: gitextract_6p1lor9m/ ├── LICENSE ├── README.md ├── configs/ │ ├── examples/ │ │ ├── constant_motion/ │ │ │ ├── head6.sh │ │ │ ├── head7.sh │ │ │ ├── kong.sh │ │ │ ├── monkey.sh │ │ │ ├── woman2.sh │ │ │ └── woman5.sh │ │ ├── multi_region/ │ │ │ ├── lawn2.sh │ │ │ └── woman.sh │ │ └── single_region/ │ │ ├── desert.sh │ │ ├── dog.sh │ │ ├── football.sh │ │ ├── forest.sh │ │ ├── head5.sh │ │ ├── lawn.sh │ │ ├── lizard.sh │ │ ├── road.sh │ │ ├── sea.sh │ │ ├── sea2.sh │ │ ├── sky.sh │ │ └── woman4.sh │ └── inference/ │ └── config_test.yaml ├── ctrl_model/ │ ├── diffusion_ctrl.py │ └── svd_ctrl.py ├── main/ │ └── inference/ │ ├── sample_constant_motion.py │ ├── sample_multi_region.py │ └── sample_single_region.py ├── requirements.txt ├── sgm/ │ ├── __init__.py │ ├── inference/ │ │ ├── api.py │ │ └── helpers.py │ ├── lr_scheduler.py │ ├── models/ │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ └── diffusion.py │ ├── modules/ │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoding/ │ │ │ ├── __init__.py │ │ │ ├── losses/ │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator_loss.py │ │ │ │ └── lpips.py │ │ │ ├── lpips/ │ │ │ │ ├── __init__.py │ │ │ │ ├── loss/ │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── lpips.py │ │ │ │ ├── model/ │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── model.py │ │ │ │ ├── util.py │ │ │ │ └── vqperceptual.py │ │ │ ├── regularizers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── quantize.py │ │ │ └── temporal_ae.py │ │ ├── diffusionmodules/ │ │ │ ├── __init__.py │ │ │ ├── denoiser.py │ │ │ ├── denoiser_scaling.py │ │ │ ├── denoiser_weighting.py │ │ │ ├── discretizer.py │ │ │ ├── guiders.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── sampling.py │ │ │ ├── sampling_utils.py │ │ │ ├── sigma_sampling.py │ │ │ ├── util.py │ │ │ ├── video_model.py │ │ │ └── wrappers.py │ │ ├── distributions/ │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders/ │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ └── video_attention.py │ └── util.py └── utils/ ├── save_video.py ├── tools.py └── visualizer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT Dated: November 21, 2023 “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein. "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. 
"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. "Stability AI" or "we" means Stability AI Ltd. "Software" means, collectively, Stability AI’s proprietary software, models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. “Software Products” means Software and Documentation. By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement. License Rights and Redistribution. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use. b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 3. Intellectual Property. a. 
No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. 4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. ================================================ FILE: README.md ================================================ # ReVideo: Remake a Video with Motion and Content Control [Chong Mou](https://scholar.google.com/citations?user=SYQoDk0AAAAJ&hl=zh-CN), [Mingdeng Cao](https://scholar.google.com/citations?user=EcS0L5sAAAAJ&hl=en), [Xintao Wang](https://xinntao.github.io/), [Zhaoyang Zhang](https://zzyfd.github.io/), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ), [Jian Zhang](https://jianzhang.tech/) [![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://mc-e.github.io/project/ReVideo/) [![arXiv](https://img.shields.io/badge/ArXiv-2405.13865-brightgreen)](https://arxiv.org/abs/2405.13865) --- ## Introduction ReVideo aims to solve the problem of local video editing. The editing target includes visual content and motion trajectory modifications.

## 📰 **New Features/Updates**

- [2024/09/25] ReVideo is accepted by NeurIPS 2024.
- [2024/06/26] We release the code of ReVideo.
- [2024/05/26] **Long video editing plan**: We are collaborating with the [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) team to replace SVD with the Sora framework, making ReVideo suitable for long video editing. Here are some preliminary results; this initial combination is still limited in quality for long videos. We will continue to collaborate and release high-quality long video editing models in the future.
*(Preliminary results table: Generated by Open-Sora | Editing Result.)*
- [2024/05/23] The paper and project page of **ReVideo** are available.

## ✏️ Todo
- [x] Code will be open-sourced in June

## 🔥🔥🔥 Main Features

### Change content & Customize motion trajectory
*(Video table: Input Video | Editing Result.)*
### Change content & Keep motion trajectory
*(Video table: Input Video | Editing Result.)*
### Keep content & Customize motion trajectory
*(Video table: Input Video | Editing Result.)*
### Multi-area Editing
*(Video table: Input Video | Editing Result.)*
## 🔧 Dependencies and Installation

- Python >= 3.8 (Recommended to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 2.0.1](https://pytorch.org/)

```bash
pip install -r requirements.txt
```

## ⏬ Download Models

All models will be downloaded automatically. You can also choose to download them manually from this [url](https://huggingface.co/Adapter/ReVideo).

**Since our ReVideo is trained based on [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), the usage guidelines for the model should follow Stable Video Diffusion's [NC-COMMUNITY LICENSE](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid/blob/main/LICENSE)!**

## 💻 How to Test

You can download the testset from [https://huggingface.co/Adapter/ReVideo](https://huggingface.co/Adapter/ReVideo). Inference requires at least `20GB` of GPU memory for editing a `768x1344` video.

```bash
bash configs/examples/constant_motion/head6.sh
```

### Description of input parameters

```bash
--s_h         # The abscissa of the top-left corner of the editing region
--e_h         # The abscissa of the lower-right corner of the editing region
--s_w         # The ordinate of the top-left corner of the editing region
--e_w         # The ordinate of the lower-right corner of the editing region
--ps_h        # The abscissa of the start point
--pe_h        # The abscissa of the end point
--ps_w        # The ordinate of the start point
--pe_w        # The ordinate of the end point
--x_bias_all  # Horizontal offset of reciprocating motion
--y_bias_all  # Vertical offset of reciprocating motion
```

A short Python sketch showing how these arguments can be assembled into an inference command is given after the Related Works section below.

## Related Works

[1] https://pika.art/

[2] DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory

[3] DragAnything: Motion Control for Anything using Entity Representation

[4] AnyV2V: A Plug-and-Play Framework For Any Video-to-Video Editing Tasks
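
As referenced in "Description of input parameters" above, here is a minimal sketch (not part of the repository) of how those arguments can be composed into a call to `main/inference/sample_single_region.py`. Note that the single-region example scripts use `--x_bias`/`--y_bias` rather than `--x_bias_all`/`--y_bias_all`; the fixed settings and coordinates below mirror `configs/examples/single_region/dog.sh`, and `build_single_region_cmd` is a hypothetical helper name.

```python
# Minimal sketch (not part of the repository): compose a single-region
# editing command in the same form as configs/examples/single_region/*.sh.
# `build_single_region_cmd` is a hypothetical helper; paths assume the
# testset from https://huggingface.co/Adapter/ReVideo has been downloaded.
import subprocess

def build_single_region_cmd(
    input_dir,        # folder of input video frames (--input)
    ref_image,        # edited first frame (--path_ref)
    region,           # (s_h, e_h, s_w, e_w): editing box in original pixels
    start_points,     # [(ps_h, ps_w), ...] trajectory start points
    end_points,       # [(pe_h, pe_w), ...] trajectory end points
    ckpt="ckpt/model.ckpt",
    config="configs/inference/config_test.yaml",
    savedir="outputs/example",
):
    s_h, e_h, s_w, e_w = region
    return [
        "python3", "main/inference/sample_single_region.py",
        "--seed", "23",
        "--ckpt", ckpt,
        "--config", config,
        "--savedir", savedir,
        "--savefps", "10",
        "--ddim_steps", "15",
        "--frames", "14",
        "--input", input_dir,
        "--path_ref", ref_image,
        "--fps", "10",
        "--motion", "127",
        "--cond_aug", "0.02",
        "--decoding_t", "1",
        "--resize",
        "--s_h", str(s_h), "--e_h", str(e_h),
        "--s_w", str(s_w), "--e_w", str(e_w),
        "--ps_h", *[str(h) for h, _ in start_points],
        "--pe_h", *[str(h) for h, _ in end_points],
        "--ps_w", *[str(w) for _, w in start_points],
        "--pe_w", *[str(w) for _, w in end_points],
        "--x_bias", "0",   # no reciprocating motion in this example
        "--y_bias", "0",
    ]

if __name__ == "__main__":
    # Values mirror configs/examples/single_region/dog.sh.
    cmd = build_single_region_cmd(
        input_dir="testset/dog",
        ref_image="testset/reference/dog.png",
        region=(83, 1477, 97, 982),
        start_points=[(803, 429), (608, 261), (1022, 340)],
        end_points=[(574, 429), (529, 261), (811, 340)],
    )
    subprocess.run(cmd, check=True)
```

The region box and trajectory points are given in pixel coordinates of the original frames; in `main/inference/sample_constant_motion.py` they are rescaled by the resize factor after the input is resized, and the other sampling scripts are expected to behave the same way.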

# 🤗 Acknowledgements We appreciate the releasing code of [Stable Video Diffusion](https://github.com/Stability-AI/generative-models). ================================================ FILE: configs/examples/constant_motion/head6.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/head6" path_ref="testset/reference/head6.png" res_dir="outputs" python3 main/inference/sample_constant_motion.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 25 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 247 \ --e_h 661 \ --s_w 46 \ --e_w 620 \ --ps_h 389 350 464 385 \ --ps_w 264 363 363 440 ================================================ FILE: configs/examples/constant_motion/head7.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/head7" path_ref="testset/reference/head7.png" res_dir="outputs" python3 main/inference/sample_constant_motion.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 25 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 1672 \ --e_h 2433 \ --s_w 330 \ --e_w 704 \ --ps_h 1854 2082 1932 \ --ps_w 479 501 622 ================================================ FILE: configs/examples/constant_motion/kong.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/kong" path_ref="testset/reference/kong.png" res_dir="outputs" python3 main/inference/sample_constant_motion.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 290 \ --e_h 1128 \ --s_w 0 \ --e_w 456 \ --ps_h 720 911 835 877 \ --ps_w 223 258 150 123 ================================================ FILE: configs/examples/constant_motion/monkey.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/monkey" path_ref="testset/reference/monkey.png" res_dir="outputs" python3 main/inference/sample_constant_motion.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 25 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 461 \ --e_h 1039 \ --s_w 189 \ --e_w 770 \ --ps_h 615 736 694 \ --ps_w 457 461 586 ================================================ FILE: configs/examples/constant_motion/woman2.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/woman2" path_ref="testset/reference/woman2.png" res_dir="outputs" python3 main/inference/sample_constant_motion.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ 
--frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 797 \ --e_h 1098 \ --s_w 315 \ --e_w 647 \ --ps_h 888 1009 950 \ --ps_w 391 385 502 ================================================ FILE: configs/examples/constant_motion/woman5.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/woman5" path_ref="testset/reference/woman5.png" res_dir="outputs" python3 main/inference/sample_constant_motion.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 605 \ --e_h 1449 \ --s_w 599 \ --e_w 1929 \ --ps_h 825 1263 1671 \ --ps_w 1123 1287 1231 ================================================ FILE: configs/examples/multi_region/lawn2.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/lawn2" path_ref="testset/reference/lawn2.png" res_dir="outputs" python3 main/inference/sample_multi_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 6 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 1579 2577 509 1405 2369 \ --e_h 2155 2991 821 1723 2725 \ --s_w 311 119 1207 1257 1189 \ --e_w 673 393 1653 1663 1641 \ --ps_h 2887 2027 675 1577 2561 \ --pe_h 2707 1777 675 1577 2561 \ --ps_w 221 449 1347 1395 1367 \ --pe_w 221 449 1347 1395 1367 \ --x_bias_all 0 \ --x_bias_all 0 \ --x_bias_all 10 20 30 20 10 0 \ --x_bias_all 10 20 30 20 10 0 \ --x_bias_all 10 20 30 20 10 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 ================================================ FILE: configs/examples/multi_region/woman.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/woman" path_ref="testset/reference/woman.png" res_dir="outputs" python3 main/inference/sample_multi_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 6 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 437 1624 2140 85 328 1921 \ --e_h 752 1968 2399 402 728 2336 \ --s_w 324 339 434 243 565 223 \ --e_w 597 515 648 442 771 412 \ --ps_h 531 1718 2209 169 409 2007 \ --pe_h 681 1900 2331 318 552 2261 \ --ps_w 397 396 589 284 646 292 \ --pe_w 523 433 495 381 697 353 \ --x_bias_all 0 \ --x_bias_all 0 \ --x_bias_all 0 \ --x_bias_all 0 \ --x_bias_all 0 \ --x_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 \ --y_bias_all 0 ================================================ FILE: configs/examples/single_region/desert.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/desert" path_ref="testset/reference/desert.png" res_dir="outputs" python3 main/inference/sample_single_region.py 
\ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 1021 \ --e_h 2428 \ --s_w 1290 \ --e_w 1896 \ --ps_h 1805 \ --pe_h 1935 \ --ps_w 1585 \ --pe_w 1554 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/dog.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/dog" path_ref="testset/reference/dog.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 83 \ --e_h 1477 \ --s_w 97 \ --e_w 982 \ --ps_h 803 608 1022 \ --pe_h 574 529 811 \ --ps_w 429 261 340 \ --pe_w 429 261 340 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/football.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/football" path_ref="testset/reference/football.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 105 \ --e_h 1119 \ --s_w 910 \ --e_w 1301 \ --ps_h 629 \ --pe_h 789 \ --ps_w 1075 \ --pe_w 1075 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/forest.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/forest" path_ref="testset/reference/forest.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 1045 \ --e_h 1497 \ --s_w 1511 \ --e_w 2073 \ --ps_h 1275 \ --pe_h 1275 \ --ps_w 1695 \ --pe_w 1695 \ --x_bias 10 20 10 0 \ --y_bias 0 0 0 0 ================================================ FILE: configs/examples/single_region/head5.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/head5" path_ref="testset/reference/head5.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 828 \ --e_h 1494 \ --s_w 0 \ --e_w 766 \ --ps_h 1157 \ --pe_h 1356 \ --ps_w 529 \ --pe_w 529 \ --x_bias 0 0 0 0 \ --y_bias 10 20 10 0 ================================================ FILE: 
configs/examples/single_region/lawn.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/lawn" path_ref="testset/reference/lawn.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 383 \ --e_h 1109 \ --s_w 485 \ --e_w 1627 \ --ps_h 717 741 \ --pe_h 717 741 \ --ps_w 731 945 \ --pe_w 781 995 \ --x_bias 15 30 45 30 15 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/lizard.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/lizard" path_ref="testset/reference/lizard.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 152 \ --e_h 1062 \ --s_w 248 \ --e_w 1021 \ --ps_h 520 \ --pe_h 700 \ --ps_w 542 \ --pe_w 542 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/road.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/road" path_ref="testset/reference/road.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 886 \ --e_h 1235 \ --s_w 704 \ --e_w 944 \ --ps_h 1115 \ --pe_h 1046 \ --ps_w 790 \ --pe_w 881 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/sea.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/sea2" path_ref="testset/reference/sea2.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 547 \ --e_h 1400 \ --s_w 488 \ --e_w 882 \ --ps_h 939 \ --pe_h 1026 \ --ps_w 649 \ --pe_w 687 \ --x_bias 0 0 0 0 \ --y_bias 10 20 10 0 ================================================ FILE: configs/examples/single_region/sea2.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/sea2" path_ref="testset/reference/sea2_2.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input 
$image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 306 \ --e_h 1077 \ --s_w 617 \ --e_w 962 \ --ps_h 737 \ --pe_h 485 \ --ps_w 761 \ --pe_w 764 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/sky.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/sky" path_ref="testset/reference/sky.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 580 \ --e_h 1037 \ --s_w 397 \ --e_w 607 \ --ps_h 780 \ --pe_h 709 \ --ps_w 516 \ --pe_w 532 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/examples/single_region/woman4.sh ================================================ name="svd-example2[fps6_mb127-temp]" config="configs/inference/config_test.yaml" ckpt="ckpt/model.ckpt" image_input="testset/woman4" path_ref="testset/reference/woman4.png" res_dir="outputs" python3 main/inference/sample_single_region.py \ --seed 23 \ --ckpt $ckpt \ --config $config \ --savedir $res_dir/$name \ --savefps 10 \ --ddim_steps 15 \ --frames 14 \ --savefps 10 \ --input $image_input \ --path_ref $path_ref \ --fps 10 \ --motion 127 \ --cond_aug 0.02 \ --decoding_t 1 --resize \ --s_h 94 \ --e_h 613 \ --s_w 236 \ --e_w 952 \ --ps_h 278 \ --pe_h 420 \ --ps_w 640 \ --pe_w 640 \ --x_bias 0 \ --y_bias 0 ================================================ FILE: configs/inference/config_test.yaml ================================================ num_frames: &num_frames 14 model: base_learning_rate: 3.0e-5 target: ctrl_model.diffusion_ctrl.DiffusionEngineCtrl params: scale_factor: 0.18215 input_key: video items: ['canny', 'depth', 'sketch', 'org'] probabilities: [0.3, 0.3, 0.3, 0.1] no_cond_log: true en_and_decode_n_samples_a_time: 1 use_ema: false disable_first_stage_autocast: true denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: ctrl_model.svd_ctrl.ControledVideoUnet params: num_frames: *num_frames adm_in_channels: 768 num_classes: sequential use_checkpoint: true in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: true transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: true use_spatial_context: true merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] controlnet_config: target: ctrl_model.svd_ctrl.VideoCtrlNet params: num_frames: *num_frames adm_in_channels: 768 num_classes: sequential use_checkpoint: true in_channels: 8 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: true transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: true use_spatial_context: true merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioning_channels: 2 conditioning_channels_mask: 3 
conditioning_channels_region: 1 conditioning_embedding_out_channels: [16, 32, 96, 256] ctrlnet_scale: 1 if_init_from_unet: false ctrlnet_ckpt_path: ckpt/model.ckpt conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: false input_key: cond_frames_without_noise ucg_rate: 0.1 target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: true # version: "ckpt/clip-h-14/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/blobs/9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4" - input_key: fps_id is_trainable: false target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: false target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: false ucg_rate: 0.1 target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: true n_cond_frames: 1 n_copies: 1 is_ae: true encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: false target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: vanilla double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.autoencoding.temporal_ae.VideoDecoder params: attn_type: vanilla double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 video_kernel_size: [3, 1, 1] sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 25 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: num_frames: *num_frames max_scale: 4 #2.5 min_scale: 1 #1.0 additional_cond_keys: ["ctrl_input", "mask", "region"] ================================================ FILE: ctrl_model/diffusion_ctrl.py ================================================ from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union from einops import rearrange, repeat import torch import torch.nn as nn import torch.nn.functional as F import torch as th from omegaconf import ListConfig, OmegaConf from safetensors.torch import load_file as load_safetensors from torch.optim.lr_scheduler import LambdaLR from sgm.models.diffusion import DiffusionEngine from sgm.util import (default, disabled_train, get_obj_from_str, instantiate_from_config, log_txt_as_img) from sgm.modules.diffusionmodules.wrappers import OpenAIWrapper import cv2 import numpy as np from utils.save_video import 
save_rgb_video, save_flow_video class DiffusionEngineCtrl(DiffusionEngine): def __init__( self, network_config, denoiser_config, first_stage_config, controlnet_config: Optional[Dict] = None, ctrlnet_key: Optional[str] = None, items = None, probabilities = None, conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, network_wrapper: Union[None, str] = None, ckpt_path: Union[None, str] = None, use_ema: bool = False, ema_decay_rate: float = 0.9999, scale_factor: float = 1.0, disable_first_stage_autocast=False, input_key: str = "jpg", log_keys: Union[List, None] = None, no_cond_log: bool = False, compile_model: bool = False, en_and_decode_n_samples_a_time: Optional[int] = None, kernel_size = 199, sigma = 20, ): super().__init__( network_config, denoiser_config, first_stage_config, conditioner_config, sampler_config, optimizer_config, scheduler_config, loss_fn_config, network_wrapper, ckpt_path, use_ema, ema_decay_rate, scale_factor, disable_first_stage_autocast, input_key, log_keys, no_cond_log, compile_model, en_and_decode_n_samples_a_time, ) self.items = items self.probabilities = probabilities ctrlnet_model = instantiate_from_config(controlnet_config) if ctrlnet_model.ctrlnet_ckpt_path is not None: ctrlnet_model.init_from_ckpt() elif ctrlnet_model.if_init_from_unet == True: ctrlnet_model.init_from_unet(self.model.diffusion_model) else: print('random initial UNet') self.ctrlnet_key = ctrlnet_key self.model = CtrlNetWrapper( self.model.diffusion_model, compile_model=False, # the UNet may be compiled in the super().__init__() ctrlnet_model=ctrlnet_model, ) class CtrlNetWrapper(OpenAIWrapper): def __init__(self, diffusion_model, compile_model: bool = False, ctrlnet_model: nn.Module = None): super().__init__(diffusion_model, compile_model) self.ctrlnet_model = ctrlnet_model def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) ctrlnet_cond = c.get("ctrl_input") mask = c.get("mask") region = c.get("region") assert ctrlnet_cond is not None, "Input SVD CtrlNet conditon is None!!!" 
down_block_res_samples, mid_block_res_sample = self.ctrlnet_model( x, mask=mask, region=region, conds=ctrlnet_cond, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), **kwargs, ) return self.diffusion_model( x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), mid_block_additional_residual=mid_block_res_sample, down_block_additional_residuals=down_block_res_samples, **kwargs, ) ================================================ FILE: ctrl_model/svd_ctrl.py ================================================ from functools import partial from typing import List, Optional, Union, Tuple from einops import rearrange import torch import torch.nn as nn import torch.nn.functional as F import torch as th from sgm.modules.diffusionmodules.openaimodel import * from sgm.modules.video_attention import SpatialVideoTransformer from sgm.util import default from sgm.modules.diffusionmodules.util import AlphaBlender from sgm.modules.diffusionmodules.video_model import VideoResBlock, VideoUNet import numpy as np import cv2 class ControledVideoUnet(VideoUNet): """ Only modify the forward function by adding additional controls """ def forward( self, x: th.Tensor, timesteps: th.Tensor, context: Optional[th.Tensor] = None, y: Optional[th.Tensor] = None, time_context: Optional[th.Tensor] = None, num_video_frames: Optional[int] = None, image_only_indicator: Optional[th.Tensor] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None ): assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) ## tbd: check the role of "image_only_indicator" num_video_frames = self.num_frames image_only_indicator = torch.zeros( x.shape[0]//num_video_frames, num_video_frames ).to(x.device) if image_only_indicator is None else image_only_indicator if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) ## x shape: [bt,c,h,w] h = x hs = [] for module in self.input_blocks: h = module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) hs.append(h) h = self.middle_block( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) # svd ctrl if mid_block_additional_residual is not None: h = h + mid_block_additional_residual for module in self.output_blocks: if down_block_additional_residuals is not None: h = th.cat([h, hs.pop() + down_block_additional_residuals.pop()], dim=1) else: h = th.cat([h, hs.pop()], dim=1) h = module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) return self.out(h) class ControlNetConditioningEmbedding(nn.Module): def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class MaskEmbedding(nn.Module): def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class WeightEmbedding(nn.Module): def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels*2, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) self.sig = nn.Sigmoid() self.conv_final = nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=3, padding=1) def forward(self, conditioning, t, cond_embeddings, mask_embeddings): t_cond = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(conditioning.shape) conditioning = torch.cat([conditioning, t_cond], dim=1) w = self.conv_in(conditioning) w = F.silu(w) for block in self.blocks: w = block(w) w = F.silu(w) w = self.conv_out(w) w = self.sig(w) embedding = cond_embeddings*w+mask_embeddings*(1-w) embedding = self.conv_final(embedding) return embedding class VideoCtrlNet(nn.Module): def __init__( self, in_channels: int, model_channels: int, num_frames: int, num_res_blocks: int, attention_resolutions: int, dropout: float = 0.0, channel_mult: List[int] = (1, 2, 4, 8), conv_resample: bool = True, dims: int = 2, num_classes: Optional[int] = None, use_checkpoint: bool = False, num_heads: int = -1, num_head_channels: int = -1, num_heads_upsample: int = -1, use_scale_shift_norm: bool = False, resblock_updown: bool = False, transformer_depth: Union[List[int], int] = 1, transformer_depth_middle: Optional[int] = None, context_dim: Optional[int] = None, time_downup: bool = False, time_context_dim: Optional[int] = None, extra_ff_mix_layer: bool = False, use_spatial_context: bool = False, merge_strategy: str = "fixed", merge_factor: float = 0.5, spatial_transformer_attn_type: str = "softmax", video_kernel_size: Union[int, List[int]] = 3, use_linear_in_transformer: bool = False, adm_in_channels: Optional[int] = None, disable_temporal_crossattention: bool = False, max_ddpm_temb_period: int = 10000, # ctrlnet conditioning_channels: int = 3, conditioning_channels_mask: int = 1, conditioning_channels_region: int = 1, conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), ctrlnet_scale: float = 1.0, ctrlnet_ckpt_path: Optional[str] = None, if_init_from_unet = True, ): super(VideoCtrlNet, self).__init__() assert context_dim is not None if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert num_head_channels != -1 if num_head_channels == -1: assert num_heads != -1 self.in_channels = in_channels self.model_channels = model_channels self.num_frames = num_frames if isinstance(transformer_depth, int): transformer_depth = len(channel_mult) * [transformer_depth] transformer_depth_middle = default( transformer_depth_middle, transformer_depth[-1] ) self.dims = dims self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.if_init_from_unet = if_init_from_unet self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.ctrlnet_scale = ctrlnet_scale self.ctrlnet_ckpt_path = ctrlnet_ckpt_path 
time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "timestep": self.label_emb = nn.Sequential( Timestep(model_channels), nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ), ) elif self.num_classes == "sequential": assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) ) else: raise ValueError() # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=model_channels, block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) self.mask_embedding = MaskEmbedding( conditioning_embedding_channels=model_channels, block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels_mask, ) self.weight_embedding = WeightEmbedding( conditioning_embedding_channels=model_channels, block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels_region, ) self.controlnet_down_blocks = nn.ModuleList([ self.make_zero_conv(model_channels), ]) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 def get_attention_layer( ch, num_heads, dim_head, depth=1, context_dim=None, use_checkpoint=False, disabled_sa=False, ): return SpatialVideoTransformer( ch, num_heads, dim_head, depth=depth, context_dim=context_dim, time_context_dim=time_context_dim, dropout=dropout, ff_in=extra_ff_mix_layer, use_spatial_context=use_spatial_context, merge_strategy=merge_strategy, merge_factor=merge_factor, checkpoint=use_checkpoint, use_linear=use_linear_in_transformer, attn_mode=spatial_transformer_attn_type, disable_self_attn=disabled_sa, disable_temporal_crossattention=disable_temporal_crossattention, max_time_embed_period=max_ddpm_temb_period, ) def get_resblock( merge_factor, merge_strategy, video_kernel_size, ch, time_embed_dim, dropout, out_ch, dims, use_checkpoint, use_scale_shift_norm, down=False, up=False, ): return VideoResBlock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, channels=ch, emb_channels=time_embed_dim, dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=down, up=up, ) for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels layers.append( get_attention_layer( ch, num_heads, 
dim_head, depth=transformer_depth[level], context_dim=context_dim, use_checkpoint=use_checkpoint, disabled_sa=False, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) # ctrlnet blocks self.controlnet_down_blocks.append(self.make_zero_conv(ch)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: ds *= 2 out_ch = ch self.input_blocks.append( TimestepEmbedSequential( get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch, third_down=time_downup, ) ) ) ch = out_ch input_block_chans.append(ch) self.controlnet_down_blocks.append(self.make_zero_conv(ch)) self._feature_size += ch # self.step_cur = 0 if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels # ctrlnet mid block self.controlnet_mid_block = self.make_zero_conv(ch) self.middle_block = TimestepEmbedSequential( get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, out_ch=None, dropout=dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, use_checkpoint=use_checkpoint, ), get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, out_ch=None, time_embed_dim=time_embed_dim, dropout=dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch def make_zero_conv(self, channels): return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) def init_from_ckpt(self, ckpt_path=None): path = self.ctrlnet_ckpt_path if ckpt_path is not None: path = ckpt_path if path.endswith("ckpt"): sd = torch.load(path, map_location="cpu")["state_dict"] sd = filter(lambda x: x[0].startswith("model.ctrlnet_model"), sd.items()) sd = {k.replace("model.ctrlnet_model.", ""): v for k, v in sd} elif path.endswith("bin"): sd = torch.load(path, map_location="cpu") elif path.endswith("safetensors"): sd = load_safetensors(path) else: raise NotImplementedError missing, unexpected = self.load_state_dict(sd, strict=False) print( f"@CtrlNet: Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: print(f"Missing Keys: {missing}") if len(unexpected) > 0: print(f"Unexpected Keys: {unexpected}") def init_from_unet(self, unet): self.input_blocks.load_state_dict(unet.input_blocks.state_dict()) self.middle_block.load_state_dict(unet.middle_block.state_dict()) print("init from unet successfully!") def forward( self, x: th.Tensor, conds: th.Tensor, mask: th.Tensor, region: th.Tensor, timesteps: th.Tensor, context: Optional[th.Tensor] = None, y: Optional[th.Tensor] = None, time_context: Optional[th.Tensor] = None, num_video_frames: Optional[int] = None, image_only_indicator: Optional[th.Tensor] = None, ): assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 
emb = self.time_embed(t_emb) ## tbd: check the role of "image_only_indicator" num_video_frames = self.num_frames image_only_indicator = torch.zeros( x.shape[0]//num_video_frames, num_video_frames ).to(x.device) if image_only_indicator is None else image_only_indicator if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) # ctrlnet cond_embeddings = self.controlnet_cond_embedding(conds) mask = mask.permute(0,2,1,3,4) mask = mask.reshape(mask.shape[0]*mask.shape[1], mask.shape[2], mask.shape[3], mask.shape[4]) mask_embeddings = self.mask_embedding(mask) cond_embeddings = self.weight_embedding(region, timesteps, cond_embeddings, mask_embeddings) ##*0.1 ## x shape: [bt,c,h,w] h = x down_block_res_samples = [] for module, zero_module in zip(self.input_blocks, self.controlnet_down_blocks): if cond_embeddings is not None: # for the conv_in block h = module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) h += cond_embeddings cond_embeddings = None else: h = module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) down_block_res_samples.append( zero_module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) * self.ctrlnet_scale ) h = self.middle_block( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) mid_block_res_sample = self.controlnet_mid_block( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) * self.ctrlnet_scale return down_block_res_samples, mid_block_res_sample ================================================ FILE: main/inference/sample_constant_motion.py ================================================ import datetime, time import os, sys, argparse import math from glob import glob from pathlib import Path from typing import Optional import cv2 import numpy as np import torch from einops import rearrange, repeat from fire import Fire from omegaconf import OmegaConf from PIL import Image from torchvision.transforms import ToTensor sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) from sgm.util import default, instantiate_from_config from utils.save_video import save_flow_video, save_rgb_video import torch.nn as nn from utils.visualizer import Visualizer import time from utils.tools import resize_pil_image, quick_freeze, get_gaussian_kernel, get_batch, get_unique_embedder_keys_from_conditioner, load_model if not os.path.exists('ckpt'): os.makedirs('ckpt') if not os.path.exists('ckpt/model.ckpt'): torch.hub.download_url_to_file( 'https://huggingface.co/Adapter/ReVideo/resolve/main/model.ckpt', 'ckpt/model.ckpt') def sample( input_path: str = "outputs/inputs/test_image.png", # Can either be image file or folder with image files path_ref: str = None, ckpt: str = "checkpoints/svd.safetensors", config: str = None, num_frames: Optional[int] = None, num_steps: Optional[int] = None, version: str = "svd", fps_id: int = 6, motion_bucket_id: int = 127, cond_aug: float = 0.02, seed: int = 23, decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 
device: str = "cuda", output_folder: Optional[str] = None, save_fps: int = 10, resize: Optional[bool] = False, # points = None s_w = None, e_w = None, s_h = None, e_h = None, ps_h = None, ps_w = None, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. """ # flow cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device) guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device) visualizer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=2) visualizer_layer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4) torch.manual_seed(seed) all_img_paths = os.listdir(input_path) all_img_paths.sort() for i in range(len(all_img_paths)): all_img_paths[i] = os.path.join(input_path, all_img_paths[i]) print(f'loaded {len(all_img_paths)} images.') os.makedirs(output_folder, exist_ok=True) images = [] for no, input_img_path in enumerate(all_img_paths): filepath, fullflname = os.path.split(input_img_path) filename, ext = os.path.splitext(fullflname) print(f'-sample {no+1}: {filename} ...') with Image.open(input_img_path) as image: if image.mode == "RGBA": image = image.convert("RGB") w_org, h_org = image.size image = resize_pil_image(image, max_resolution=1024*1024) # image = resize_pil_image(image, max_resolution=896*896) scale_h = image.size[0]/w_org scale_w = image.size[1]/h_org w, h = image.size if h % 64 != 0 or w % 64 != 0: width, height = map(lambda x: x - x % 64, (w, h)) image = image.resize((width, height)) print( f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" ) image = ToTensor()(image) image = image * 2.0 - 1.0 image = image.unsqueeze(0).to(device) H, W = image.shape[2:] assert image.shape[1] == 3 F = 8 C = 4 shape = (num_frames, C, H // F, W // F) images.append(image) images = images[:num_frames]#[::-1]#.flip(dims=(1,)) images = torch.stack(images, dim=2) save_rgb_video((images.flip(1)+1)/2.,'outputs/org.mp4') img_ref = Image.open(path_ref) if img_ref.mode == "RGBA": img_ref = img_ref.convert("RGB") img_ref = img_ref.resize((images.shape[-1], images.shape[-2])) img_ref = ToTensor()(img_ref) img_ref = img_ref * 2.0 - 1.0 img_ref = img_ref.unsqueeze(0).to(device) with torch.no_grad(): vid_cotracker = ((images+1)/2.).permute(0,2,1,3,4)*255. 
grid = torch.zeros(1,len(ps_h),3) for i in range(len(ps_h)): grid[0,i,1] = ps_h[i]*scale_h grid[0,i,2] = ps_w[i]*scale_w grid = grid.to(device) tracks, _ = cotracker(vid_cotracker, queries=grid) # B T N 2, B T N 1 layer = torch.ones_like(images.permute(0,2,1,3,4)) res_video = visualizer_layer.visualize(layer.cpu()*255., tracks=tracks, save_video=False) res_video = ( (rearrange(res_video[0], "t c h w -> t h w c")) .numpy() .astype(np.uint8) ) frame = cv2.cvtColor(res_video[-1], cv2.COLOR_RGB2BGR) filename = input_path.split('/')[-1] cv2.imwrite(os.path.join('vis_im', filename, 'layer.png'), frame) b,c,n,w,h = images.shape s_w = int(s_w*scale_w) e_w = int(e_w*scale_w) s_h = int(s_h*scale_h) e_h = int(e_h*scale_h) model_config = config model, filter = load_model( model_config, ckpt, device, num_frames, num_steps, ) select_point = tracks.flip(-1) maps = [] for i in range(num_frames-1): map = torch.zeros((1,2,img_ref.shape[-2],img_ref.shape[-1])).to(device) rows = select_point[:, i+1, :, 0].to(device).int()#, dtype=torch.int64) cols = select_point[:, i+1, :, 1].to(device).int()#, dtype=torch.int64) rows = torch.clip(rows, 0, w-1) cols = torch.clip(cols, 0, h-1) map[0,:,rows[0], cols[0]] = (select_point[0, i+1] - select_point[0, i]).permute(1,0)#flow[kk, jj+1, :, rows[kk], cols[kk]].clone()#.cpu() maps.append(map) maps = [torch.zeros_like(maps[0])]+maps maps = torch.stack(maps, dim=2)#.reshape(b*n,2,w,h) guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device) with torch.no_grad(): maps = maps.permute(0,2,1,3,4).reshape(b*n,2,w,h) maps = guassian_filter(maps).reshape(b,n,2,w,h) images[:,:,:,s_w:e_w,s_h:e_h] = -1 save_flow_video(maps.permute(0,2,1,3,4), 'outputs/flow.mp4') save_rgb_video((images.flip(1)+1)/2.,'outputs/content.mp4') flow = maps.reshape(b*n,2,w,h) region = torch.zeros_like(images)[:,:1] region[:,:,:,s_w:e_w,s_h:e_h] = 1 region = region.permute(0,2,1,3,4).reshape(b*n,1,w,h) value_dict = {} print(images.shape) value_dict["video"] = images.to(dtype=torch.float16) value_dict["region"] = region.to(dtype=torch.float16) value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug value_dict["cond_frames_without_noise"] = img_ref.to(dtype=torch.float16) #images[:,:,0] value_dict["cond_frames"] = (img_ref + cond_aug * torch.randn_like(images[:,:,0])).to(dtype=torch.float16) print(cond_aug) with torch.no_grad(): with torch.autocast(device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1, num_frames], T=num_frames, device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=num_frames) randn = torch.randn(shape, device=device, dtype = torch.float16) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros( 2, num_frames ).to(device, dtype=torch.float16) #additional_model_inputs["image_only_indicator"][:,0] = 1 additional_model_inputs["num_video_frames"] = batch["num_video_frames"] def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) c['mask'] = images.clone() uc['mask'] = images.clone() c['region'] = region.clone() uc['region'] = region.clone() c['ctrl_input'] = flow.clone() uc['ctrl_input'] = flow.clone() samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) filename = input_path.split('/')[-1] res_video = visualizer.visualize(samples.unsqueeze(0).cpu()*255., tracks=tracks, save_video=False) visualizer.save_video(res_video, filename='result_cotracker_%s'%filename, writer=None, step=0) os.makedirs(os.path.join('vis_im', filename, 'im_w_track'), exist_ok=True) os.makedirs(os.path.join('vis_im', filename, 'im_wo_track'), exist_ok=True) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255) .cpu() .numpy() .astype(np.uint8) ) vid_wo_track = ( (rearrange(res_video[0], "t c h w -> t h w c")) .cpu() .numpy() .astype(np.uint8) ) for idx, frame in enumerate(vid): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) name = '%04d.png'%idx cv2.imwrite(os.path.join('vis_im', filename, 'im_wo_track', name), frame) for idx, frame in enumerate(vid_wo_track): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) name = '%04d.png'%idx cv2.imwrite(os.path.join('vis_im', filename, 'im_w_track', name), frame) print(f'Done! 
results saved in {output_folder}.') def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=23, help="seed for seed_everything") parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path") parser.add_argument("--config", type=str, help="config (yaml) path") parser.add_argument("--input", type=str, default=None, help="image path or folder") parser.add_argument("--path_ref", type=str, default=None, help="reference image path") parser.add_argument("--savedir", type=str, default=None, help="results saving path") parser.add_argument("--savefps", type=int, default=10, help="video fps to generate") parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",) parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",) parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",) parser.add_argument("--frames", type=int, default=-1, help="frames num to inference") parser.add_argument("--fps", type=int, default=6, help="control the fps") parser.add_argument("--motion", type=int, default=127, help="control the motion magnitude") parser.add_argument("--cond_aug", type=float, default=0.02, help="adding noise to input image") parser.add_argument("--decoding_t", type=int, default=1, help="frames num to decoding per time") parser.add_argument("--resize", action='store_true', default=False, help="resize all input to default resolution") parser.add_argument("--s_w", type=int, default=None) parser.add_argument("--e_w", type=int, default=None) parser.add_argument("--s_h", type=int, default=None) parser.add_argument("--e_h", type=int, default=None) parser.add_argument("--ps_h", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--ps_w", metavar="N", type=int, nargs="+", default=None) return parser if __name__ == "__main__": now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") print("@SVD Inference: %s"%now) #Fire(sample) parser = get_parser() args = parser.parse_args() sample(input_path=args.input, path_ref=args.path_ref, ckpt=args.ckpt, config=args.config, num_frames=args.frames, num_steps=args.ddim_steps, \ fps_id=args.fps, motion_bucket_id=args.motion, cond_aug=args.cond_aug, seed=args.seed, \ decoding_t=args.decoding_t, output_folder=args.savedir, save_fps=args.savefps, resize=args.resize, s_w=args.s_w, e_w=args.e_w, s_h=args.s_h, e_h=args.e_h, ps_w=args.ps_w, ps_h=args.ps_h) ================================================ FILE: main/inference/sample_multi_region.py ================================================ import datetime, time import os, sys, argparse import math from glob import glob from pathlib import Path from typing import Optional import cv2 import numpy as np import torch from einops import rearrange, repeat from fire import Fire from PIL import Image from torchvision.transforms import ToTensor sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) from utils.save_video import save_flow_video, save_rgb_video import torch.nn as nn from utils.visualizer import Visualizer from utils.tools import resize_pil_image, quick_freeze, get_gaussian_kernel, get_batch, get_unique_embedder_keys_from_conditioner, load_model if not os.path.exists('ckpt'): os.makedirs('ckpt') if not os.path.exists('ckpt/model.ckpt'): torch.hub.download_url_to_file( 'https://huggingface.co/Adapter/ReVideo/resolve/main/model.ckpt', 'ckpt/model.ckpt') def sample( input_path: str = 
"outputs/inputs/test_image.png", # Can either be image file or folder with image files path_ref: str = None, ckpt: str = "checkpoints/svd.safetensors", config: str = None, num_frames: Optional[int] = None, num_steps: Optional[int] = None, version: str = "svd", fps_id: int = 6, motion_bucket_id: int = 127, cond_aug: float = 0.02, seed: int = 23, decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", output_folder: Optional[str] = None, save_fps: int = 10, resize: Optional[bool] = False, # points = None s_w = None, e_w = None, s_h = None, e_h = None, ps_w = None, pe_w = None, ps_h = None, pe_h = None, x_bias_all = None, y_bias_all = None, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. """ # flow guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device) visualizer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=2) visualizer_layer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4) torch.manual_seed(seed) all_img_paths = os.listdir(input_path) all_img_paths.sort() for i in range(len(all_img_paths)): all_img_paths[i] = os.path.join(input_path, all_img_paths[i]) print(f'loaded {len(all_img_paths)} images.') os.makedirs(output_folder, exist_ok=True) images = [] for no, input_img_path in enumerate(all_img_paths): filepath, fullflname = os.path.split(input_img_path) filename, ext = os.path.splitext(fullflname) print(f'-sample {no+1}: {filename} ...') with Image.open(input_img_path) as image: if image.mode == "RGBA": image = image.convert("RGB") w_org, h_org = image.size image = resize_pil_image(image, max_resolution=1024*1024) scale_h = image.size[0]/w_org scale_w = image.size[1]/h_org w, h = image.size if h % 64 != 0 or w % 64 != 0: width, height = map(lambda x: x - x % 64, (w, h)) image = image.resize((width, height)) print( f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" ) image = ToTensor()(image) image = image * 2.0 - 1.0 image = image.unsqueeze(0).to(device) H, W = image.shape[2:] assert image.shape[1] == 3 F = 8 C = 4 shape = (num_frames, C, H // F, W // F) images.append(image) images = images[:num_frames] images = torch.stack(images, dim=2) save_rgb_video((images.flip(1)+1)/2.,'outputs/org.mp4') img_ref = Image.open(path_ref) if img_ref.mode == "RGBA": img_ref = img_ref.convert("RGB") img_ref = img_ref.resize((images.shape[-1], images.shape[-2])) img_ref = ToTensor()(img_ref) img_ref = img_ref * 2.0 - 1.0 img_ref = img_ref.unsqueeze(0).to(device) b,c,n,w,h = images.shape map = torch.zeros((1,2,num_frames-1,img_ref.shape[-2],img_ref.shape[-1])).to(device) region = torch.zeros_like(images)[:,:1] tracks = [] assert len(x_bias_all)==len(y_bias_all) and len(x_bias_all)==len(s_w), 'Wrong bias!' 
for k in range(len(s_w)): s_w[k] = int(s_w[k]*scale_w) e_w[k] = int(e_w[k]*scale_w) s_h[k] = int(s_h[k]*scale_h) e_h[k] = int(e_h[k]*scale_h) images[:,:,:,s_w[k]:e_w[k],s_h[k]:e_h[k]] = -1 region[:,:,:,s_w[k]:e_w[k],s_h[k]:e_h[k]] = 1 p_start = [ps_h[k]*scale_h, ps_w[k]*scale_w] p_end = [pe_h[k]*scale_h, pe_w[k]*scale_w] x_dist = p_end[0]-p_start[0] y_dist = p_end[1]-p_start[1] inter_x = x_dist/(num_frames-2) inter_y = y_dist/(num_frames-2) x_bias = x_bias_all[k] y_bias = y_bias_all[k] for i in range(num_frames-1): x_cur = int(p_start[0]+i*inter_x + x_bias[i%len(x_bias)]) y_cur = int(p_start[1]+i*inter_y + y_bias[i%len(y_bias)]) if i == 0: x_per = p_start[0] y_per = p_start[1] else: x_per = int(p_start[0]+(i-1)*inter_x + x_bias[(i-1)%len(x_bias)]) y_per = int(p_start[1]+(i-1)*inter_y+ y_bias[(i-1)%len(y_bias)]) map[:,1,i,y_cur,x_cur] = x_cur - x_per map[:,0,i,y_cur,x_cur] = y_cur - y_per track = torch.zeros(b,num_frames,1,2) for i in range(num_frames): if i == 0: x_cur = p_start[0] y_cur = p_start[1] else: x_cur = int(p_start[0]+i*inter_x + x_bias[i%len(x_bias)]) y_cur = int(p_start[1]+i*inter_y + y_bias[i%len(y_bias)]) track[0,i,0,0]=x_cur track[0,i,0,1]=y_cur tracks.append(track) tracks = torch.cat(tracks, dim=2) layer = torch.ones_like(images.permute(0,2,1,3,4)) res_video = visualizer_layer.visualize(layer.cpu()*255., tracks=tracks, save_video=False) res_video = ( (rearrange(res_video[0], "t c h w -> t h w c")) .numpy() .astype(np.uint8) ) frame = cv2.cvtColor(res_video[-1], cv2.COLOR_RGB2BGR) filename = input_path.split('/')[-1] cv2.imwrite(os.path.join('vis_im', filename, 'layer.png'), frame) pad = torch.zeros_like(map[:,:,:1,:,:]) map = torch.cat([pad, map], dim=2) with torch.no_grad(): map = map.permute(0,2,1,3,4).reshape(b*n,2,w,h) map = guassian_filter(map).reshape(b,n,2,w,h) save_flow_video(map.permute(0,2,1,3,4), 'outputs/flow.mp4') save_rgb_video((images.flip(1)+1)/2.,'outputs/content.mp4') flow = map.reshape(b*n,2,w,h) model_config = config model, filter = load_model( model_config, ckpt, device, num_frames, num_steps, ) region = region.permute(0,2,1,3,4).reshape(b*n,1,w,h) value_dict = {} value_dict["video"] = images.to(dtype=torch.float16) value_dict["region"] = region.to(dtype=torch.float16) value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug value_dict["cond_frames_without_noise"] = img_ref.to(dtype=torch.float16) #images[:,:,0] value_dict["cond_frames"] = img_ref.to(dtype=torch.float16) #images[:,:,0] # + cond_aug * torch.randn_like(images[:,:,0]) with torch.no_grad(): with torch.autocast(device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1, num_frames], T=num_frames, device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=num_frames) randn = torch.randn(shape, device=device, dtype=torch.float16) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros( 2, num_frames ).to(device, dtype=torch.float16) additional_model_inputs["num_video_frames"] = batch["num_video_frames"] def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) c['mask'] = images.clone() uc['mask'] = images.clone() c['region'] = region.clone() uc['region'] = region.clone() c['ctrl_input'] = flow.clone() uc['ctrl_input'] = flow.clone() samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) filename = input_path.split('/')[-1] res_video = visualizer.visualize(samples.unsqueeze(0).cpu()*255., tracks=tracks, save_video=False) visualizer.save_video(res_video, filename='result_cotracker_%s'%filename, writer=None, step=0) os.makedirs(os.path.join('vis_im', filename, 'im_w_track'), exist_ok=True) os.makedirs(os.path.join('vis_im', filename, 'im_wo_track'), exist_ok=True) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255) .cpu() .numpy() .astype(np.uint8) ) vid_wo_track = ( (rearrange(res_video[0], "t c h w -> t h w c")) .cpu() .numpy() .astype(np.uint8) ) for idx, frame in enumerate(vid): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) name = '%04d.png'%idx cv2.imwrite(os.path.join('vis_im', filename, 'im_wo_track', name), frame) for idx, frame in enumerate(vid_wo_track): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) name = '%04d.png'%idx cv2.imwrite(os.path.join('vis_im', filename, 'im_w_track', name), frame) print(f'Done! results saved in {output_folder}.') def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=23, help="seed for seed_everything") parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path") parser.add_argument("--config", type=str, help="config (yaml) path") parser.add_argument("--input", type=str, default=None, help="image path or folder") parser.add_argument("--path_ref", type=str, default=None, help="reference image path") parser.add_argument("--savedir", type=str, default=None, help="results saving path") parser.add_argument("--savefps", type=int, default=10, help="video fps to generate") parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",) parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",) parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",) parser.add_argument("--frames", type=int, default=-1, help="frames num to inference") parser.add_argument("--fps", type=int, default=6, help="control the fps") parser.add_argument("--motion", type=int, default=127, help="control the motion magnitude") parser.add_argument("--cond_aug", type=float, default=0.02, help="adding noise to input image") parser.add_argument("--decoding_t", type=int, default=1, help="frames num to decoding per time") parser.add_argument("--resize", action='store_true', default=False, help="resize all input to default resolution") parser.add_argument("--s_w", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--e_w", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--s_h", metavar="N", type=int, nargs="+", default=None) 
parser.add_argument("--e_h", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--ps_w", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--pe_w", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--ps_h", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--pe_h", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--x_bias_all", nargs="+", action="append", type=int, help="Horizontal swing") parser.add_argument("--y_bias_all", nargs="+", action="append", type=int, help="Vertical swing") return parser if __name__ == "__main__": now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") print("@SVD Inference: %s"%now) #Fire(sample) parser = get_parser() args = parser.parse_args() sample(input_path=args.input, path_ref=args.path_ref, ckpt=args.ckpt, config=args.config, num_frames=args.frames, num_steps=args.ddim_steps, \ fps_id=args.fps, motion_bucket_id=args.motion, cond_aug=args.cond_aug, seed=args.seed, \ decoding_t=args.decoding_t, output_folder=args.savedir, save_fps=args.savefps, resize=args.resize, \ s_w=args.s_w, e_w=args.e_w, s_h=args.s_h, e_h=args.e_h, ps_w=args.ps_w, pe_w=args.pe_w, ps_h=args.ps_h, \ pe_h=args.pe_h, x_bias_all=args.x_bias_all, y_bias_all=args.y_bias_all) ================================================ FILE: main/inference/sample_single_region.py ================================================ import datetime, time import os, sys, argparse import math from glob import glob from pathlib import Path from typing import Optional import cv2 import numpy as np import torch from einops import rearrange, repeat from fire import Fire from omegaconf import OmegaConf from PIL import Image from torchvision.transforms import ToTensor sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) from sgm.util import default, instantiate_from_config from utils.save_video import save_flow_video, save_rgb_video import torch.nn as nn from utils.visualizer import Visualizer from utils.tools import resize_pil_image, quick_freeze, get_gaussian_kernel, get_batch, get_unique_embedder_keys_from_conditioner, load_model if not os.path.exists('ckpt'): os.makedirs('ckpt') if not os.path.exists('ckpt/model.ckpt'): torch.hub.download_url_to_file( 'https://huggingface.co/Adapter/ReVideo/resolve/main/model.ckpt', 'ckpt/model.ckpt') def sample( input_path: str = "outputs/inputs/test_image.png", # Can either be image file or folder with image files path_ref: str = None, ckpt: str = "checkpoints/svd.safetensors", config: str = None, num_frames: Optional[int] = None, num_steps: Optional[int] = None, version: str = "svd", fps_id: int = 6, motion_bucket_id: int = 127, cond_aug: float = 0.02, seed: int = 23, decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", output_folder: Optional[str] = None, save_fps: int = 10, resize: Optional[bool] = False, # points = None s_w = None, e_w = None, s_h = None, e_h = None, ps_w = None, pe_w = None, ps_h = None, pe_h = None, x_bias = None, y_bias = None, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. 
""" guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device) visualizer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4) visualizer_layer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4) torch.manual_seed(seed) all_img_paths = os.listdir(input_path) all_img_paths.sort() for i in range(len(all_img_paths)): all_img_paths[i] = os.path.join(input_path, all_img_paths[i]) print(f'loaded {len(all_img_paths)} images.') os.makedirs(output_folder, exist_ok=True) images = [] for no, input_img_path in enumerate(all_img_paths): filepath, fullflname = os.path.split(input_img_path) filename, ext = os.path.splitext(fullflname) print(f'-sample {no+1}: {filename} ...') with Image.open(input_img_path) as image: if image.mode == "RGBA": image = image.convert("RGB") w_org, h_org = image.size image = resize_pil_image(image, max_resolution=1024*1024) scale_h = image.size[0]/w_org scale_w = image.size[1]/h_org w, h = image.size if h % 64 != 0 or w % 64 != 0: width, height = map(lambda x: x - x % 64, (w, h)) image = image.resize((width, height)) print( f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" ) image = ToTensor()(image) image = image * 2.0 - 1.0 image = image.unsqueeze(0).to(device) H, W = image.shape[2:] assert image.shape[1] == 3 F = 8 C = 4 shape = (num_frames, C, H // F, W // F) images.append(image) images = images[:num_frames]#[::-1]#.flip(dims=(1,)) images = torch.stack(images, dim=2) save_rgb_video((images.flip(1)+1)/2.,'outputs/org.mp4') img_ref = Image.open(path_ref) if img_ref.mode == "RGBA": img_ref = img_ref.convert("RGB") img_ref = img_ref.resize((images.shape[-1], images.shape[-2])) img_ref = ToTensor()(img_ref) img_ref = img_ref * 2.0 - 1.0 img_ref = img_ref.unsqueeze(0).to(device) b,c,n,w,h = images.shape s_w = int(s_w*scale_w) e_w = int(e_w*scale_w) s_h = int(s_h*scale_h) e_h = int(e_h*scale_h) p_start = [[ps_h_i*scale_h, ps_w_i*scale_w] for ps_h_i, ps_w_i in zip(ps_h, ps_w)] p_end = [[pe_h_i*scale_h, pe_w_i*scale_w] for pe_h_i, pe_w_i in zip(pe_h, pe_w)] tracks = torch.zeros(b,num_frames,len(p_start),2) for k in range(len(p_start)): x_dist = p_end[k][0]-p_start[k][0] y_dist = p_end[k][1]-p_start[k][1] inter_x = x_dist/(num_frames-2) inter_y = y_dist/(num_frames-2) for i in range(num_frames): if i == 0: x_cur = p_start[k][0] y_cur = p_start[k][1] else: x_cur = int(p_start[k][0]+i*inter_x + x_bias[i%len(x_bias)]) y_cur = int(p_start[k][1]+i*inter_y + y_bias[i%len(y_bias)]) tracks[0,i,k,0]=x_cur tracks[0,i,k,1]=y_cur layer = torch.ones_like(images.permute(0,2,1,3,4)) res_video = visualizer_layer.visualize(layer.cpu()*255., tracks=tracks, save_video=False) res_video = ( (rearrange(res_video[0], "t c h w -> t h w c")) .numpy() .astype(np.uint8) ) frame = cv2.cvtColor(res_video[-1], cv2.COLOR_RGB2BGR) filename = input_path.split('/')[-1] cv2.imwrite(os.path.join('vis_im', filename, 'layer.png'), frame) model_config = config model, filter = load_model( model_config, ckpt, device, num_frames, num_steps, ) map = torch.zeros((1,2,num_frames-1,img_ref.shape[-2],img_ref.shape[-1])).to(device) for k in range(len(p_start)): x_dist = p_end[k][0]-p_start[k][0] y_dist = p_end[k][1]-p_start[k][1] inter_x = x_dist/(num_frames-2) inter_y = y_dist/(num_frames-2) for i in range(num_frames-1): x_cur = int(p_start[k][0]+i*inter_x + x_bias[i%len(x_bias)]) y_cur = int(p_start[k][1]+i*inter_y + y_bias[i%len(y_bias)]) if i == 0: x_per = 
p_start[k][0] y_per = p_start[k][1] else: x_per = int(p_start[k][0]+(i-1)*inter_x + x_bias[(i-1)%len(x_bias)]) y_per = int(p_start[k][1]+(i-1)*inter_y+ y_bias[(i-1)%len(y_bias)]) map[:,1,i,y_cur,x_cur] = x_cur - x_per map[:,0,i,y_cur,x_cur] = y_cur - y_per pad = torch.zeros_like(map[:,:,:1,:,:]) map = torch.cat([pad, map], dim=2) guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device) with torch.no_grad(): map = map.permute(0,2,1,3,4).reshape(b*n,2,w,h) map = guassian_filter(map).reshape(b,n,2,w,h) images[:,:,:,s_w:e_w,s_h:e_h] = -1 save_flow_video(map.permute(0,2,1,3,4), 'outputs/flow.mp4') save_rgb_video((images.flip(1)+1)/2.,'outputs/content.mp4') flow = map.reshape(b*n,2,w,h) region = torch.zeros_like(images)[:,:1] region[:,:,:,s_w:e_w,s_h:e_h] = 1 region = region.permute(0,2,1,3,4).reshape(b*n,1,w,h) value_dict = {} value_dict["video"] = images.to(dtype=torch.float16) value_dict["region"] = region.to(dtype=torch.float16) value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug value_dict["cond_frames_without_noise"] = img_ref.to(dtype=torch.float16) #images[:,:,0] value_dict["cond_frames"] = (img_ref + cond_aug * torch.randn_like(images[:,:,0])).to(dtype=torch.float16) #images[:,:,0] # + cond_aug * torch.randn_like(images[:,:,0]) with torch.no_grad(): with torch.autocast(device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1, num_frames], T=num_frames, device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=num_frames) randn = torch.randn(shape, device=device, dtype=torch.float16) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros( 2, num_frames ).to(device, dtype=torch.float16) #additional_model_inputs["image_only_indicator"][:,0] = 1 additional_model_inputs["num_video_frames"] = batch["num_video_frames"] def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) c['mask'] = images.clone() uc['mask'] = images.clone() c['region'] = region.clone() uc['region'] = region.clone() c['ctrl_input'] = flow.clone() uc['ctrl_input'] = flow.clone() samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) # model.first_stage_model.to(dtype=torch.float32) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) filename = input_path.split('/')[-1] res_video = visualizer.visualize(samples.unsqueeze(0).cpu()*255., tracks=tracks, save_video=False) visualizer.save_video(res_video, filename='result_cotracker_%s'%filename, writer=None, step=0) os.makedirs(os.path.join('vis_im', filename, 'im_w_track'), exist_ok=True) os.makedirs(os.path.join('vis_im', filename, 'im_wo_track'), exist_ok=True) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255) .cpu() .numpy() .astype(np.uint8) ) vid_wo_track = ( (rearrange(res_video[0], "t c h w -> t h w c")) .cpu() .numpy() .astype(np.uint8) ) for idx, frame in enumerate(vid): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) name = '%04d.png'%idx cv2.imwrite(os.path.join('vis_im', filename, 'im_wo_track', name), frame) for idx, frame in enumerate(vid_wo_track): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) name = '%04d.png'%idx cv2.imwrite(os.path.join('vis_im', filename, 'im_w_track', name), frame) print(f'Done! 
results saved in {output_folder}.') def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=23, help="seed for seed_everything") parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path") parser.add_argument("--config", type=str, help="config (yaml) path") parser.add_argument("--input", type=str, default=None, help="image path or folder") parser.add_argument("--path_ref", type=str, default=None, help="reference image path") parser.add_argument("--savedir", type=str, default=None, help="results saving path") parser.add_argument("--savefps", type=int, default=10, help="video fps to generate") parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",) parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",) parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",) parser.add_argument("--frames", type=int, default=-1, help="frames num to inference") parser.add_argument("--fps", type=int, default=6, help="control the fps") parser.add_argument("--motion", type=int, default=127, help="control the motion magnitude") parser.add_argument("--cond_aug", type=float, default=0.02, help="adding noise to input image") parser.add_argument("--decoding_t", type=int, default=1, help="frames num to decoding per time") parser.add_argument("--resize", action='store_true', default=False, help="resize all input to default resolution") parser.add_argument("--s_w", type=int, default=None) parser.add_argument("--e_w", type=int, default=None) parser.add_argument("--s_h", type=int, default=None) parser.add_argument("--e_h", type=int, default=None) parser.add_argument("--ps_w", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--pe_w", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--ps_h", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--pe_h", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--x_bias", metavar="N", type=int, nargs="+", default=None) parser.add_argument("--y_bias", metavar="N", type=int, nargs="+", default=None) return parser if __name__ == "__main__": now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") print("@SVD Inference: %s"%now) #Fire(sample) parser = get_parser() args = parser.parse_args() sample(input_path=args.input, path_ref=args.path_ref, ckpt=args.ckpt, config=args.config, num_frames=args.frames, num_steps=args.ddim_steps, \ fps_id=args.fps, motion_bucket_id=args.motion, cond_aug=args.cond_aug, seed=args.seed, \ decoding_t=args.decoding_t, output_folder=args.savedir, save_fps=args.savefps, resize=args.resize, \ s_w=args.s_w, e_w=args.e_w, s_h=args.s_h, e_h=args.e_h, ps_w=args.ps_w, pe_w=args.pe_w, ps_h=args.ps_h, pe_h=args.pe_h, x_bias=args.x_bias, y_bias=args.y_bias) ================================================ FILE: requirements.txt ================================================ black==23.7.0 chardet==5.1.0 clip einops>=0.6.1 fairscale==0.4.13 fire==0.5.0 fsspec>=2023.6.0 invisible-watermark>=0.2.0 kornia==0.6.9 matplotlib==3.7.2 natsort>=8.4.0 ninja>=1.11.1 numpy>=1.24.4 omegaconf>=2.3.0 open-clip-torch>=2.20.0 opencv-python==4.6.0.66 pandas>=2.0.3 pillow>=9.5.0 pudb>=2022.1.3 pytorch-lightning==1.9.0 pyyaml>=6.0.1 scipy>=1.10.1 streamlit>=0.73.1 tensorboardx==2.6 timm==0.4.12 tokenizers==0.12.1 torch==2.0.1 torchaudio>=2.0.2 torchdata==0.6.1 torchmetrics>=1.0.1 
torchvision>=0.15.2 tqdm>=4.65.0 transformers==4.19.1 triton==2.0.0 urllib3<1.27,>=1.25.4 wandb>=0.15.6 webdataset>=0.2.33 wheel>=0.41.0 xformers==0.0.21 decord deepspeed yacs controlnet_aux PyAV safetensors imageio[ffmpeg] gradio==3.50.2 safetensors ================================================ FILE: sgm/__init__.py ================================================ from .models import AutoencodingEngine, DiffusionEngine from .util import get_configs_path, instantiate_from_config __version__ = "0.1.0" ================================================ FILE: sgm/inference/api.py ================================================ import pathlib from dataclasses import asdict, dataclass from enum import Enum from typing import Optional from omegaconf import OmegaConf from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img, do_sample) from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, DPMPP2SAncestralSampler, EulerAncestralSampler, EulerEDMSampler, HeunEDMSampler, LinearMultistepSampler) from sgm.util import load_model_from_config class ModelArchitecture(str, Enum): SD_2_1 = "stable-diffusion-v2-1" SD_2_1_768 = "stable-diffusion-v2-1-768" SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" SDXL_V1_BASE = "stable-diffusion-xl-v1-base" SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" class Sampler(str, Enum): EULER_EDM = "EulerEDMSampler" HEUN_EDM = "HeunEDMSampler" EULER_ANCESTRAL = "EulerAncestralSampler" DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler" DPMPP2M = "DPMPP2MSampler" LINEAR_MULTISTEP = "LinearMultistepSampler" class Discretization(str, Enum): LEGACY_DDPM = "LegacyDDPMDiscretization" EDM = "EDMDiscretization" class Guider(str, Enum): VANILLA = "VanillaCFG" IDENTITY = "IdentityGuider" class Thresholder(str, Enum): NONE = "None" @dataclass class SamplingParams: width: int = 1024 height: int = 1024 steps: int = 50 sampler: Sampler = Sampler.DPMPP2M discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE scale: float = 6.0 aesthetic_score: float = 5.0 negative_aesthetic_score: float = 5.0 img2img_strength: float = 1.0 orig_width: int = 1024 orig_height: int = 1024 crop_coords_top: int = 0 crop_coords_left: int = 0 sigma_min: float = 0.0292 sigma_max: float = 14.6146 rho: float = 3.0 s_churn: float = 0.0 s_tmin: float = 0.0 s_tmax: float = 999.0 s_noise: float = 1.0 eta: float = 1.0 order: int = 4 @dataclass class SamplingSpec: width: int height: int channels: int factor: int is_legacy: bool config: str ckpt: str is_guided: bool model_specs = { ModelArchitecture.SD_2_1: SamplingSpec( height=512, width=512, channels=4, factor=8, is_legacy=True, config="sd_2_1.yaml", ckpt="v2-1_512-ema-pruned.safetensors", is_guided=True, ), ModelArchitecture.SD_2_1_768: SamplingSpec( height=768, width=768, channels=4, factor=8, is_legacy=True, config="sd_2_1_768.yaml", ckpt="v2-1_768-ema-pruned.safetensors", is_guided=True, ), ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( height=1024, width=1024, channels=4, factor=8, is_legacy=False, config="sd_xl_base.yaml", ckpt="sd_xl_base_0.9.safetensors", is_guided=True, ), ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( height=1024, width=1024, channels=4, factor=8, is_legacy=True, config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_0.9.safetensors", is_guided=True, ), ModelArchitecture.SDXL_V1_BASE: SamplingSpec( height=1024, width=1024, channels=4, factor=8, is_legacy=False, 
config="sd_xl_base.yaml", ckpt="sd_xl_base_1.0.safetensors", is_guided=True, ), ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( height=1024, width=1024, channels=4, factor=8, is_legacy=True, config="sd_xl_refiner.yaml", ckpt="sd_xl_refiner_1.0.safetensors", is_guided=True, ), } class SamplingPipeline: def __init__( self, model_id: ModelArchitecture, model_path="checkpoints", config_path="configs/inference", device="cuda", use_fp16=True, ) -> None: if model_id not in model_specs: raise ValueError(f"Model {model_id} not supported") self.model_id = model_id self.specs = model_specs[self.model_id] self.config = str(pathlib.Path(config_path, self.specs.config)) self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16) def _load_model(self, device="cuda", use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") model.to(device) if use_fp16: model.conditioner.half() model.model.half() return model def text_to_image( self, params: SamplingParams, prompt: str, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, ): sampler = get_sampler_config(params) value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt value_dict["target_width"] = params.width value_dict["target_height"] = params.height return do_sample( self.model, sampler, value_dict, samples, params.height, params.width, self.specs.channels, self.specs.factor, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=None, ) def image_to_image( self, params: SamplingParams, image, prompt: str, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, ): sampler = get_sampler_config(params) if params.img2img_strength < 1.0: sampler.discretization = Img2ImgDiscretizationWrapper( sampler.discretization, strength=params.img2img_strength, ) height, width = image.shape[2], image.shape[3] value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt value_dict["target_width"] = width value_dict["target_height"] = height return do_img2img( image, self.model, sampler, value_dict, samples, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=None, ) def refiner( self, params: SamplingParams, image, prompt: str, negative_prompt: Optional[str] = None, samples: int = 1, return_latents: bool = False, ): sampler = get_sampler_config(params) value_dict = { "orig_width": image.shape[3] * 8, "orig_height": image.shape[2] * 8, "target_width": image.shape[3] * 8, "target_height": image.shape[2] * 8, "prompt": prompt, "negative_prompt": negative_prompt, "crop_coords_top": 0, "crop_coords_left": 0, "aesthetic_score": 6.0, "negative_aesthetic_score": 2.5, } return do_img2img( image, self.model, sampler, value_dict, samples, skip_encode=True, return_latents=return_latents, filter=None, ) def get_guider_config(params: SamplingParams): if params.guider == Guider.IDENTITY: guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" } elif params.guider == Guider.VANILLA: scale = params.scale thresholder = params.thresholder if thresholder == Thresholder.NONE: dyn_thresh_config = { "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" } else: raise NotImplementedError 
guider_config = { "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, } else: raise NotImplementedError return guider_config def get_discretization_config(params: SamplingParams): if params.discretization == Discretization.LEGACY_DDPM: discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", } elif params.discretization == Discretization.EDM: discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", "params": { "sigma_min": params.sigma_min, "sigma_max": params.sigma_max, "rho": params.rho, }, } else: raise ValueError(f"unknown discretization {params.discretization}") return discretization_config def get_sampler_config(params: SamplingParams): discretization_config = get_discretization_config(params) guider_config = get_guider_config(params) sampler = None if params.sampler == Sampler.EULER_EDM: return EulerEDMSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=params.s_churn, s_tmin=params.s_tmin, s_tmax=params.s_tmax, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.HEUN_EDM: return HeunEDMSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=params.s_churn, s_tmin=params.s_tmin, s_tmax=params.s_tmax, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.EULER_ANCESTRAL: return EulerAncestralSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, eta=params.eta, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.DPMPP2S_ANCESTRAL: return DPMPP2SAncestralSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, eta=params.eta, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.DPMPP2M: return DPMPP2MSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, verbose=True, ) if params.sampler == Sampler.LINEAR_MULTISTEP: return LinearMultistepSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, order=params.order, verbose=True, ) raise ValueError(f"unknown sampler {params.sampler}!") ================================================ FILE: sgm/inference/helpers.py ================================================ import math import os from typing import List, Optional, Union import numpy as np import torch from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig from PIL import Image from torch import autocast from sgm.util import append_dims class WatermarkEmbedder: def __init__(self, watermark): self.watermark = watermark self.num_bits = len(WATERMARK_BITS) self.encoder = WatermarkEncoder() self.encoder.set_watermark("bits", self.watermark) def __call__(self, image: torch.Tensor) -> torch.Tensor: """ Adds a predefined watermark to the input image Args: image: ([N,] B, RGB, H, W) in range [0, 1] Returns: same as input but watermarked """ squeeze = len(image.shape) == 4 if squeeze: image = image[None, ...] 
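# Added commentary: the code below flattens the (optionally 5-D) tensor into individual
# frames, converts them to uint8 BGR as expected by the invisible-watermark encoder, and
# embeds WATERMARK_BITS into every frame with the "dwtDct" method. A matching check could be
# done with the decoder from the same library (illustrative sketch, not part of this repo;
# `bgr_uint8_frame` is a hypothetical H x W x 3 uint8 BGR array):
#   from imwatermark import WatermarkDecoder
#   decoder = WatermarkDecoder("bits", len(WATERMARK_BITS))
#   recovered_bits = decoder.decode(bgr_uint8_frame, "dwtDct")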
n = image.shape[0] image_np = rearrange( (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" ).numpy()[:, :, :, ::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] # watermarking libary expects input as cv2 BGR format for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image = torch.from_numpy( rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) ).to(image.device) image = torch.clamp(image / 255, min=0.0, max=1.0) if squeeze: image = image[0] return image # A fixed 48-bit message that was choosen at random # WATERMARK_MESSAGE = 0xB3EC907BB19E WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] embed_watermark = WatermarkEmbedder(WATERMARK_BITS) def get_unique_embedder_keys_from_conditioner(conditioner): return list({x.input_key for x in conditioner.embedders}) def perform_save_locally(save_path, samples): os.makedirs(os.path.join(save_path), exist_ok=True) base_count = len(os.listdir(os.path.join(save_path))) samples = embed_watermark(samples) for sample in samples: sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(sample.astype(np.uint8)).save( os.path.join(save_path, f"{base_count:09}.png") ) base_count += 1 class Img2ImgDiscretizationWrapper: """ wraps a discretizer, and prunes the sigmas params: strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) """ def __init__(self, discretization, strength: float = 1.0): self.discretization = discretization self.strength = strength assert 0.0 <= self.strength <= 1.0 def __call__(self, *args, **kwargs): # sigmas start large first, and decrease then sigmas = self.discretization(*args, **kwargs) print(f"sigmas after discretization, before pruning img2img: ", sigmas) sigmas = torch.flip(sigmas, (0,)) sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] print("prune index:", max(int(self.strength * len(sigmas)), 1)) sigmas = torch.flip(sigmas, (0,)) print(f"sigmas after pruning: ", sigmas) return sigmas def do_sample( model, sampler, value_dict, num_samples, H, W, C, F, force_uc_zero_embeddings: Optional[List] = None, batch2model_input: Optional[List] = None, return_latents=False, filter=None, device="cuda", ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] if batch2model_input is None: batch2model_input = [] with torch.no_grad(): with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples, ) for key in batch: if isinstance(batch[key], torch.Tensor): print(key, batch[key].shape) elif isinstance(batch[key], list): print(key, [len(l) for l in batch[key]]) else: print(key, batch[key]) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings, ) for k in c: if not k == "crossattn": c[k], uc[k] = map( lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) ) additional_model_inputs = {} for k in batch2model_input: additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) randn = torch.randn(shape).to(device) def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) samples_z = sampler(denoiser, randn, cond=c, uc=uc) samples_x = 
model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) if return_latents: return samples, samples_z return samples def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): # Hardcoded demo setups; might undergo some changes in the future batch = {} batch_uc = {} for key in keys: if key == "txt": batch["txt"] = ( np.repeat([value_dict["prompt"]], repeats=math.prod(N)) .reshape(N) .tolist() ) batch_uc["txt"] = ( np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) .reshape(N) .tolist() ) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) .to(device) .repeat(*N, 1) ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( torch.tensor( [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] ) .to(device) .repeat(*N, 1) ) elif key == "aesthetic_score": batch["aesthetic_score"] = ( torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) ) batch_uc["aesthetic_score"] = ( torch.tensor([value_dict["negative_aesthetic_score"]]) .to(device) .repeat(*N, 1) ) elif key == "target_size_as_tuple": batch["target_size_as_tuple"] = ( torch.tensor([value_dict["target_height"], value_dict["target_width"]]) .to(device) .repeat(*N, 1) ) else: batch[key] = value_dict[key] for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc def get_input_image_tensor(image: Image.Image, device="cuda"): w, h = image.size print(f"loaded input image of size ({w}, {h})") width, height = map( lambda x: x - x % 64, (w, h) ) # resize to integer multiple of 64 image = image.resize((width, height)) image_array = np.array(image.convert("RGB")) image_array = image_array[None].transpose(0, 3, 1, 2) image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 return image_tensor.to(device) def do_img2img( img, model, sampler, value_dict, num_samples, force_uc_zero_embeddings=[], additional_kwargs={}, offset_noise_level: float = 0.0, return_latents=False, skip_encode=False, filter=None, device="cuda", ): with torch.no_grad(): with autocast(device) as precision_scope: with model.ema_scope(): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [num_samples], ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings, ) for k in c: c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] if skip_encode: z = img else: z = model.encode_first_stage(img) noise = torch.randn_like(z) sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device) if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim ) noised_z = z + noise * append_dims(sigma, z.ndim) noised_z = noised_z / torch.sqrt( 1.0 + sigmas[0] ** 2.0 ) # Note: hardcoded to DDPM-like scaling. need to generalize later. 
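# Added commentary: the input latent is noised up to the first (largest remaining) sigma of
# the schedule and rescaled by 1/sqrt(1 + sigma^2) to match the DDPM-style variance-preserving
# convention, so sampling below proceeds exactly as in do_sample. When the discretization has
# been wrapped by Img2ImgDiscretizationWrapper with strength < 1.0, the schedule is pruned to
# its smallest sigmas, so sigmas[0] is smaller and the image is only partially noised, which
# is what keeps the result close to the input.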
def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) if return_latents: return samples, samples_z return samples ================================================ FILE: sgm/lr_scheduler.py ================================================ import numpy as np class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ def __init__( self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0, ): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = ( self.lr_max - self.lr_start ) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr return lr else: t = (n - self.lr_warm_up_steps) / ( self.lr_max_decay_steps - self.lr_warm_up_steps ) t = min(t, 1.0) lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 1 + np.cos(t * np.pi) ) self.last_lr = lr return lr def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: """ supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ def __init__( self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 ): assert ( len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) ) self.lr_warm_up_steps = warm_up_steps self.f_start = f_start self.f_min = f_min self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): interval = 0 for cl in self.cum_cycles[1:]: if n <= cl: return interval interval += 1 def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print( f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}" ) if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ cycle ] * n + self.f_start[cycle] self.last_f = f return f else: t = (n - self.lr_warm_up_steps[cycle]) / ( self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] ) t = min(t, 1.0) f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 1 + np.cos(t * np.pi) ) self.last_f = f return f def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print( f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}" ) if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ cycle ] * n + self.f_start[cycle] self.last_f = f return f else: f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( self.cycle_lengths[cycle] - n ) / (self.cycle_lengths[cycle]) self.last_f = f 
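# Added commentary: like its parent classes, this scheduler returns a learning-rate
# multiplier and is meant to be used with a base lr of 1.0, typically by plugging it into
# torch.optim.lr_scheduler.LambdaLR (illustrative values, `optimizer` is assumed to exist):
#   scheduler = LambdaLinearScheduler(warm_up_steps=[1000], f_min=[0.0], f_max=[1.0],
#                                     f_start=[0.0], cycle_lengths=[100000])
#   lr_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)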
return f ================================================ FILE: sgm/models/__init__.py ================================================ from .autoencoder import AutoencodingEngine from .diffusion import DiffusionEngine ================================================ FILE: sgm/models/autoencoder.py ================================================ import logging import math import re from abc import abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import pytorch_lightning as pl import torch import torch.nn as nn from einops import rearrange from packaging import version from ..modules.autoencoding.regularizers import AbstractRegularizer from ..modules.ema import LitEma from ..util import (default, get_nested_attribute, get_obj_from_str, instantiate_from_config) logpy = logging.getLogger(__name__) class AbstractAutoencoder(pl.LightningModule): """ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, unCLIP models, etc. Hence, it is fairly general, and specific features (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. """ def __init__( self, ema_decay: Union[None, float] = None, monitor: Union[None, str] = None, input_key: str = "jpg", ): super().__init__() self.input_key = input_key self.use_ema = ema_decay is not None if monitor is not None: self.monitor = monitor if self.use_ema: self.model_ema = LitEma(self, decay=ema_decay) logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if version.parse(torch.__version__) >= version.parse("2.0.0"): self.automatic_optimization = False def apply_ckpt(self, ckpt: Union[None, str, dict]): if ckpt is None: return if isinstance(ckpt, str): ckpt = { "target": "sgm.modules.checkpoint.CheckpointEngine", "params": {"ckpt_path": ckpt}, } engine = instantiate_from_config(ckpt) engine(self) @abstractmethod def get_input(self, batch) -> Any: raise NotImplementedError() def on_train_batch_end(self, *args, **kwargs): # for EMA computation if self.use_ema: self.model_ema(self) @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: logpy.info(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: logpy.info(f"{context}: Restored training weights") @abstractmethod def encode(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError("encode()-method of abstract base class called") @abstractmethod def decode(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError("decode()-method of abstract base class called") def instantiate_optimizer_from_config(self, params, lr, cfg): logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") return get_obj_from_str(cfg["target"])( params, lr=lr, **cfg.get("params", dict()) ) def configure_optimizers(self) -> Any: raise NotImplementedError() class AutoencodingEngine(AbstractAutoencoder): """ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL (we also restore them explicitly as special cases for legacy reasons). Regularizations such as KL or VQ are moved to the regularizer class. 
""" def __init__( self, *args, encoder_config: Dict, decoder_config: Dict, loss_config: Dict, regularizer_config: Dict, optimizer_config: Union[Dict, None] = None, lr_g_factor: float = 1.0, trainable_ae_params: Optional[List[List[str]]] = None, ae_optimizer_args: Optional[List[dict]] = None, trainable_disc_params: Optional[List[List[str]]] = None, disc_optimizer_args: Optional[List[dict]] = None, disc_start_iter: int = 0, diff_boost_factor: float = 3.0, ckpt_engine: Union[None, str, dict] = None, ckpt_path: Optional[str] = None, additional_decode_keys: Optional[List[str]] = None, **kwargs, ): super().__init__(*args, **kwargs) self.automatic_optimization = False # pytorch lightning self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) self.loss: torch.nn.Module = instantiate_from_config(loss_config) self.regularization: AbstractRegularizer = instantiate_from_config( regularizer_config ) self.optimizer_config = default( optimizer_config, {"target": "torch.optim.Adam"} ) self.diff_boost_factor = diff_boost_factor self.disc_start_iter = disc_start_iter self.lr_g_factor = lr_g_factor self.trainable_ae_params = trainable_ae_params if self.trainable_ae_params is not None: self.ae_optimizer_args = default( ae_optimizer_args, [{} for _ in range(len(self.trainable_ae_params))], ) assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) else: self.ae_optimizer_args = [{}] # makes type consitent self.trainable_disc_params = trainable_disc_params if self.trainable_disc_params is not None: self.disc_optimizer_args = default( disc_optimizer_args, [{} for _ in range(len(self.trainable_disc_params))], ) assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) else: self.disc_optimizer_args = [{}] # makes type consitent if ckpt_path is not None: assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") self.apply_ckpt(default(ckpt_path, ckpt_engine)) self.additional_decode_keys = set(default(additional_decode_keys, [])) def get_input(self, batch: Dict) -> torch.Tensor: # assuming unified data format, dataloader returns a dict. # image tensors should be scaled to -1 ... 
1 and in channels-first # format (e.g., bchw instead if bhwc) return batch[self.input_key] def get_autoencoder_params(self) -> list: params = [] if hasattr(self.loss, "get_trainable_autoencoder_parameters"): params += list(self.loss.get_trainable_autoencoder_parameters()) if hasattr(self.regularization, "get_trainable_parameters"): params += list(self.regularization.get_trainable_parameters()) params = params + list(self.encoder.parameters()) params = params + list(self.decoder.parameters()) return params def get_discriminator_params(self) -> list: if hasattr(self.loss, "get_trainable_parameters"): params = list(self.loss.get_trainable_parameters()) # e.g., discriminator else: params = [] return params def get_last_layer(self): return self.decoder.get_last_layer() def encode( self, x: torch.Tensor, return_reg_log: bool = False, unregularized: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: z = self.encoder(x) if unregularized: return z, dict() z, reg_log = self.regularization(z) if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: x = self.decoder(z, **kwargs) return x def forward( self, x: torch.Tensor, **additional_decode_kwargs ) -> Tuple[torch.Tensor, torch.Tensor, dict]: z, reg_log = self.encode(x, return_reg_log=True) dec = self.decode(z, **additional_decode_kwargs) return z, dec, reg_log def inner_training_step( self, batch: dict, batch_idx: int, optimizer_idx: int = 0 ) -> torch.Tensor: x = self.get_input(batch) additional_decode_kwargs = { key: batch[key] for key in self.additional_decode_keys.intersection(batch) } z, xrec, regularization_log = self(x, **additional_decode_kwargs) if hasattr(self.loss, "forward_keys"): extra_info = { "z": z, "optimizer_idx": optimizer_idx, "global_step": self.global_step, "last_layer": self.get_last_layer(), "split": "train", "regularization_log": regularization_log, "autoencoder": self, } extra_info = {k: extra_info[k] for k in self.loss.forward_keys} else: extra_info = dict() if optimizer_idx == 0: # autoencode out_loss = self.loss(x, xrec, **extra_info) if isinstance(out_loss, tuple): aeloss, log_dict_ae = out_loss else: # simple loss function aeloss = out_loss log_dict_ae = {"train/loss/rec": aeloss.detach()} self.log_dict( log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True, sync_dist=False, ) self.log( "loss", aeloss.mean().detach(), prog_bar=True, logger=False, on_epoch=False, on_step=True, ) return aeloss elif optimizer_idx == 1: # discriminator discloss, log_dict_disc = self.loss(x, xrec, **extra_info) # -> discriminator always needs to return a tuple self.log_dict( log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True ) return discloss else: raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") def training_step(self, batch: dict, batch_idx: int): opts = self.optimizers() if not isinstance(opts, list): # Non-adversarial case opts = [opts] optimizer_idx = batch_idx % len(opts) if self.global_step < self.disc_start_iter: optimizer_idx = 0 opt = opts[optimizer_idx] opt.zero_grad() with opt.toggle_model(): loss = self.inner_training_step( batch, batch_idx, optimizer_idx=optimizer_idx ) self.manual_backward(loss) opt.step() def validation_step(self, batch: dict, batch_idx: int) -> Dict: log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") log_dict.update(log_dict_ema) return log_dict def _validation_step(self, batch: dict, batch_idx: 
int, postfix: str = "") -> Dict: x = self.get_input(batch) z, xrec, regularization_log = self(x) if hasattr(self.loss, "forward_keys"): extra_info = { "z": z, "optimizer_idx": 0, "global_step": self.global_step, "last_layer": self.get_last_layer(), "split": "val" + postfix, "regularization_log": regularization_log, "autoencoder": self, } extra_info = {k: extra_info[k] for k in self.loss.forward_keys} else: extra_info = dict() out_loss = self.loss(x, xrec, **extra_info) if isinstance(out_loss, tuple): aeloss, log_dict_ae = out_loss else: # simple loss function aeloss = out_loss log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()} full_log_dict = log_dict_ae if "optimizer_idx" in extra_info: extra_info["optimizer_idx"] = 1 discloss, log_dict_disc = self.loss(x, xrec, **extra_info) full_log_dict.update(log_dict_disc) self.log( f"val{postfix}/loss/rec", log_dict_ae[f"val{postfix}/loss/rec"], sync_dist=True, ) self.log_dict(full_log_dict, sync_dist=True) return full_log_dict def get_param_groups( self, parameter_names: List[List[str]], optimizer_args: List[dict] ) -> Tuple[List[Dict[str, Any]], int]: groups = [] num_params = 0 for names, args in zip(parameter_names, optimizer_args): params = [] for pattern_ in names: pattern_params = [] pattern = re.compile(pattern_) for p_name, param in self.named_parameters(): if re.match(pattern, p_name): pattern_params.append(param) num_params += param.numel() if len(pattern_params) == 0: logpy.warn(f"Did not find parameters for pattern {pattern_}") params.extend(pattern_params) groups.append({"params": params, **args}) return groups, num_params def configure_optimizers(self) -> List[torch.optim.Optimizer]: if self.trainable_ae_params is None: ae_params = self.get_autoencoder_params() else: ae_params, num_ae_params = self.get_param_groups( self.trainable_ae_params, self.ae_optimizer_args ) logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") if self.trainable_disc_params is None: disc_params = self.get_discriminator_params() else: disc_params, num_disc_params = self.get_param_groups( self.trainable_disc_params, self.disc_optimizer_args ) logpy.info( f"Number of trainable discriminator parameters: {num_disc_params:,}" ) opt_ae = self.instantiate_optimizer_from_config( ae_params, default(self.lr_g_factor, 1.0) * self.learning_rate, self.optimizer_config, ) opts = [opt_ae] if len(disc_params) > 0: opt_disc = self.instantiate_optimizer_from_config( disc_params, self.learning_rate, self.optimizer_config ) opts.append(opt_disc) return opts @torch.no_grad() def log_images( self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs ) -> dict: log = dict() additional_decode_kwargs = {} x = self.get_input(batch) additional_decode_kwargs.update( {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} ) _, xrec, _ = self(x, **additional_decode_kwargs) log["inputs"] = x log["reconstructions"] = xrec diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) diff.clamp_(0, 1.0) log["diff"] = 2.0 * diff - 1.0 # diff_boost shows location of small errors, by boosting their # brightness. 
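# (worked example, assuming the default diff_boost_factor=3.0: a per-pixel error
# |xrec - x| = 0.2 gives diff = 0.1, so "diff" is logged as 2 * 0.1 - 1 = -0.8,
# while "diff_boost" is logged as 2 * clamp(3.0 * 0.1, 0, 1) - 1 = -0.4, i.e. the
# same small error appears noticeably brighter)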
log["diff_boost"] = ( 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 ) if hasattr(self.loss, "log_images"): log.update(self.loss.log_images(x, xrec)) with self.ema_scope(): _, xrec_ema, _ = self(x, **additional_decode_kwargs) log["reconstructions_ema"] = xrec_ema diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) diff_ema.clamp_(0, 1.0) log["diff_ema"] = 2.0 * diff_ema - 1.0 log["diff_boost_ema"] = ( 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 ) if additional_log_kwargs: additional_decode_kwargs.update(additional_log_kwargs) _, xrec_add, _ = self(x, **additional_decode_kwargs) log_str = "reconstructions-" + "-".join( [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] ) log[log_str] = xrec_add return log class AutoencodingEngineLegacy(AutoencodingEngine): def __init__(self, embed_dim: int, **kwargs): self.max_batch_size = kwargs.pop("max_batch_size", None) ddconfig = kwargs.pop("ddconfig") ckpt_path = kwargs.pop("ckpt_path", None) ckpt_engine = kwargs.pop("ckpt_engine", None) super().__init__( encoder_config={ "target": "sgm.modules.diffusionmodules.model.Encoder", "params": ddconfig, }, decoder_config={ "target": "sgm.modules.diffusionmodules.model.Decoder", "params": ddconfig, }, **kwargs, ) self.quant_conv = torch.nn.Conv2d( (1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * embed_dim, 1, ) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim self.apply_ckpt(default(ckpt_path, ckpt_engine)) def get_autoencoder_params(self) -> list: params = super().get_autoencoder_params() return params def encode( self, x: torch.Tensor, return_reg_log: bool = False ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.max_batch_size is None: z = self.encoder(x) z = self.quant_conv(z) else: N = x.shape[0] bs = self.max_batch_size n_batches = int(math.ceil(N / bs)) z = list() for i_batch in range(n_batches): z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) z_batch = self.quant_conv(z_batch) z.append(z_batch) z = torch.cat(z, 0) z, reg_log = self.regularization(z) if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) else: N = z.shape[0] bs = self.max_batch_size n_batches = int(math.ceil(N / bs)) dec = list() for i_batch in range(n_batches): dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) dec_batch = self.decoder(dec_batch, **decoder_kwargs) dec.append(dec_batch) dec = torch.cat(dec, 0) return dec class AutoencoderKL(AutoencodingEngineLegacy): def __init__(self, **kwargs): if "lossconfig" in kwargs: kwargs["loss_config"] = kwargs.pop("lossconfig") super().__init__( regularizer_config={ "target": ( "sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer" ) }, **kwargs, ) class AutoencoderLegacyVQ(AutoencodingEngineLegacy): def __init__( self, embed_dim: int, n_embed: int, sane_index_shape: bool = False, **kwargs, ): if "lossconfig" in kwargs: logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.") kwargs["loss_config"] = kwargs.pop("lossconfig") super().__init__( regularizer_config={ "target": ( "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer" ), "params": { "n_e": n_embed, "e_dim": embed_dim, "sane_index_shape": sane_index_shape, }, }, **kwargs, ) class IdentityFirstStage(AbstractAutoencoder): def 
__init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def get_input(self, x: Any) -> Any: return x def encode(self, x: Any, *args, **kwargs) -> Any: return x def decode(self, x: Any, *args, **kwargs) -> Any: return x class AEIntegerWrapper(nn.Module): def __init__( self, model: nn.Module, shape: Union[None, Tuple[int, int], List[int]] = (16, 16), regularization_key: str = "regularization", encoder_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__() self.model = model assert hasattr(model, "encode") and hasattr( model, "decode" ), "Need AE interface" self.regularization = get_nested_attribute(model, regularization_key) self.shape = shape self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True}) def encode(self, x) -> torch.Tensor: assert ( not self.training ), f"{self.__class__.__name__} only supports inference currently" _, log = self.model.encode(x, **self.encoder_kwargs) assert isinstance(log, dict) inds = log["min_encoding_indices"] return rearrange(inds, "b ... -> b (...)") def decode( self, inds: torch.Tensor, shape: Union[None, tuple, list] = None ) -> torch.Tensor: # expect inds shape (b, s) with s = h*w shape = default(shape, self.shape) # Optional[(h, w)] if shape is not None: assert len(shape) == 2, f"Unhandeled shape {shape}" inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1]) h = self.regularization.get_codebook_entry(inds) # (b, h, w, c) h = rearrange(h, "b h w c -> b c h w") return self.model.decode(h) class AutoencoderKLModeOnly(AutoencodingEngineLegacy): def __init__(self, **kwargs): if "lossconfig" in kwargs: kwargs["loss_config"] = kwargs.pop("lossconfig") super().__init__( regularizer_config={ "target": ( "sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer" ), "params": {"sample": False}, }, **kwargs, ) ================================================ FILE: sgm/models/diffusion.py ================================================ import math from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union from einops import rearrange, repeat import pytorch_lightning as pl import torch from omegaconf import ListConfig, OmegaConf from safetensors.torch import load_file as load_safetensors from torch.optim.lr_scheduler import LambdaLR from ..modules import UNCONDITIONAL_CONFIG from ..modules.autoencoding.temporal_ae import VideoDecoder from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER from ..modules.ema import LitEma from ..util import (default, disabled_train, get_obj_from_str, instantiate_from_config, log_txt_as_img) class DiffusionEngine(pl.LightningModule): def __init__( self, network_config, denoiser_config, first_stage_config, conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, network_wrapper: Union[None, str] = None, ckpt_path: Union[None, str] = None, use_ema: bool = False, ema_decay_rate: float = 0.9999, scale_factor: float = 1.0, disable_first_stage_autocast=False, input_key: str = "jpg", log_keys: Union[List, None] = None, no_cond_log: bool = False, compile_model: bool = False, en_and_decode_n_samples_a_time: Optional[int] = None, ): super().__init__() self.log_keys = log_keys self.input_key = input_key self.optimizer_config = default( optimizer_config, {"target": 
"torch.optim.AdamW"} ) model = instantiate_from_config(network_config) self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( model, compile_model=compile_model ) self.denoiser = instantiate_from_config(denoiser_config) self.sampler = ( instantiate_from_config(sampler_config) if sampler_config is not None else None ) self.conditioner = instantiate_from_config( default(conditioner_config, UNCONDITIONAL_CONFIG) ) self.scheduler_config = scheduler_config self._init_first_stage(first_stage_config) ## update with num_frames self.num_frames = network_config.params.num_frames self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model, decay=ema_decay_rate) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.scale_factor = scale_factor self.disable_first_stage_autocast = disable_first_stage_autocast self.no_cond_log = no_cond_log if ckpt_path is not None: self.init_from_ckpt(ckpt_path) self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time def init_from_ckpt( self, path: str, ) -> None: if path.endswith("ckpt"): sd = torch.load(path, map_location="cpu")["state_dict"] elif path.endswith("safetensors"): sd = load_safetensors(path) else: raise NotImplementedError missing, unexpected = self.load_state_dict(sd, strict=False) print( f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: print(f"Missing Keys: {missing}") if len(unexpected) > 0: print(f"Unexpected Keys: {unexpected}") # exit(0) def _init_first_stage(self, config): model = instantiate_from_config(config).eval() model.train = disabled_train for param in model.parameters(): param.requires_grad = False self.first_stage_model = model @torch.no_grad() def decode_first_stage(self, z): z = 1.0 / self.scale_factor * z n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) n_rounds = math.ceil(z.shape[0] / n_samples) all_out = [] with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds): if isinstance(self.first_stage_model.decoder, VideoDecoder): kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} else: kwargs = {} out = self.first_stage_model.decode( z[n * n_samples : (n + 1) * n_samples], **kwargs ) all_out.append(out) out = torch.cat(all_out, dim=0) return out # @torch.no_grad() def decode_first_stage_train(self, z): z = 1.0 / self.scale_factor * z n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) n_rounds = math.ceil(z.shape[0] / n_samples) all_out = [] with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds): if isinstance(self.first_stage_model.decoder, VideoDecoder): kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} else: kwargs = {} out = self.first_stage_model.decode( z[n * n_samples : (n + 1) * n_samples], **kwargs ) all_out.append(out) out = torch.cat(all_out, dim=0) return out @torch.no_grad() def encode_first_stage(self, x): n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) n_rounds = math.ceil(x.shape[0] / n_samples) all_out = [] with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds): out = self.first_stage_model.encode( x[n * n_samples : (n + 1) * n_samples] ) all_out.append(out) z = torch.cat(all_out, dim=0) z = self.scale_factor * z return z def forward(self, x, batch): loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) loss_mean = loss.mean() loss_dict = {"loss": 
loss_mean} return loss_mean, loss_dict def get_input(self, batch): # assuming unified data format, dataloader returns a dict. # image tensors should be scaled to -1 ... 1 and in bchw format x = batch[self.input_key] x = rearrange(x, "b c t h w -> (b t) c h w") return x def shared_step(self, batch: Dict) -> Any: x = self.get_input(batch) x = self.encode_first_stage(x) batch["global_step"] = self.global_step loss, loss_dict = self(x, batch) return loss, loss_dict def training_step(self, batch, batch_idx): loss, loss_dict = self.shared_step(batch) self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False) self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.scheduler_config is not None: lr = self.optimizers().param_groups[0]["lr"] self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss def on_train_start(self, *args, **kwargs): if self.sampler is None or self.loss_fn is None: raise ValueError("Sampler and loss function need to be set for training.") def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self.model) @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: print(f"{context}: Restored training weights") def instantiate_optimizer_from_config(self, params, lr, cfg): return get_obj_from_str(cfg["target"])( params, lr=lr, **cfg.get("params", dict()) ) def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) for embedder in self.conditioner.embedders: if embedder.is_trainable: params = params + list(embedder.parameters()) print(f"@Training [{len(params)}] paramters.") opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) if self.scheduler_config is not None: scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ { "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1, } ] return [opt], scheduler return opt @torch.no_grad() def sample( self, cond: Dict, uc: Union[Dict, None] = None, batch_size: int = 16, shape: Union[None, Tuple, List] = None, **kwargs, ): randn = torch.randn(batch_size, *shape).to(self.device) denoiser = lambda input, sigma, c: self.denoiser( self.model, input, sigma, c, **kwargs ) samples = self.sampler(denoiser, randn, cond, uc=uc) return samples @torch.no_grad() def log_conditionings(self, batch: Dict, n: int) -> Dict: """ Defines heuristics to log different conditionings. These can be lists of strings (text-to-image), tensors, ints, ... 
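For example, a 1-D integer tensor of class labels is converted to strings and rendered with log_txt_as_img, while 2-D tensors (e.g. size and crop conditionings) are first joined into "x"-separated strings before rendering.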
""" image_h, image_w = batch[self.input_key].shape[2:] log = dict() for embedder in self.conditioner.embedders: if ( (self.log_keys is None) or (embedder.input_key in self.log_keys) ) and not self.no_cond_log: x = batch[embedder.input_key][:n] if isinstance(x, torch.Tensor): if x.dim() == 1: # class-conditional, convert integer to string x = [str(x[i].item()) for i in range(x.shape[0])] xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) elif x.dim() == 2: # size and crop cond and the like x = [ "x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0]) ] xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) else: raise NotImplementedError() elif isinstance(x, (List, ListConfig)): if isinstance(x[0], str): # strings xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) else: raise NotImplementedError() else: raise NotImplementedError() log[embedder.input_key] = xc return log @torch.no_grad() def log_images( self, batch: Dict, N: int = 999, sample: bool = True, ucg_keys: List[str] = None, **kwargs, ) -> Dict: conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] if ucg_keys: assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( "Each defined ucg key for sampling must be in the provided conditioner input keys," f"but we have {ucg_keys} vs. {conditioner_input_keys}" ) else: ucg_keys = conditioner_input_keys log = dict() x = self.get_input(batch) c, uc = self.conditioner.get_unconditional_conditioning( batch, force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], ) ## repeat conditions for each frames for k in ["crossattn", "concat", "vector"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=self.num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=self.num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=self.num_frames) c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=self.num_frames) sampling_kwargs = {} N = min(x.shape[0], N) x = x.to(self.device)[:N] b = x.shape[0] // self.num_frames log["inputs"] = rearrange(x, '(b t) c h w -> b c t h w', b=b, t=self.num_frames) z = self.encode_first_stage(x) if self.disable_first_stage_autocast: self.disable_first_stage_autocast = False x_recon = self.decode_first_stage(z) self.disable_first_stage_autocast = True else: x_recon = self.decode_first_stage(z) log["reconstructions"] = rearrange(x_recon, '(b t) c h w -> b c t h w', b=b, t=self.num_frames) for k in c: if isinstance(c[k], torch.Tensor): c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) if sample: with self.ema_scope("Plotting"): samples = self.sample( c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs ) if self.disable_first_stage_autocast: self.disable_first_stage_autocast = False samples = self.decode_first_stage(samples) self.disable_first_stage_autocast = True else: samples = self.decode_first_stage(samples) log["samples"] = rearrange(samples, '(b t) c h w -> b c t h w', b=b, t=self.num_frames) return log ================================================ FILE: sgm/modules/__init__.py ================================================ from .encoders.modules import GeneralConditioner UNCONDITIONAL_CONFIG = { "target": "sgm.modules.GeneralConditioner", "params": {"emb_models": []}, } ================================================ FILE: sgm/modules/attention.py ================================================ import logging import math from inspect import isfunction from typing import Any, Optional import torch import torch.nn.functional as F from einops import rearrange, repeat from packaging import version from torch import nn from torch.utils.checkpoint import checkpoint logpy = logging.getLogger(__name__) if version.parse(torch.__version__) >= version.parse("2.0.0"): SDP_IS_AVAILABLE = True from torch.backends.cuda import SDPBackend, sdp_kernel BACKEND_MAP = { SDPBackend.MATH: { "enable_math": True, "enable_flash": False, "enable_mem_efficient": False, }, SDPBackend.FLASH_ATTENTION: { "enable_math": False, "enable_flash": True, "enable_mem_efficient": False, }, SDPBackend.EFFICIENT_ATTENTION: { "enable_math": False, "enable_flash": False, "enable_mem_efficient": True, }, None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, } else: from contextlib import nullcontext SDP_IS_AVAILABLE = False sdp_kernel = nullcontext BACKEND_MAP = {} logpy.warn( f"No SDP backend available, likely because you are running in pytorch " f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " f"You might want to consider upgrading." ) try: import xformers import xformers.ops XFORMERS_IS_AVAILABLE = True except: XFORMERS_IS_AVAILABLE = False logpy.warn("no module 'xformers'. 
Processing without...") # from .diffusionmodules.util import mixed_checkpoint as checkpoint def exists(val): return val is not None def uniq(arr): return {el: True for el in arr}.keys() def default(val, d): if exists(val): return val return d() if isfunction(d) else d def max_neg_value(t): return -torch.finfo(t.dtype).max def init_(tensor): dim = tensor.shape[-1] std = 1 / math.sqrt(dim) tensor.uniform_(-std, std) return tensor # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = ( nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) ) self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): return self.net(x) def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def Normalize(in_channels): return torch.nn.GroupNorm( num_groups=32, num_channels=in_channels, eps=1e-6, affine=True ) class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange( qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 ) k = k.softmax(dim=-1) context = torch.einsum("bhdn,bhen->bhde", k, v) out = torch.einsum("bhde,bhdn->bhen", context, q) out = rearrange( out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w ) return self.to_out(out) class SelfAttention(nn.Module): ATTENTION_MODES = ("xformers", "torch", "math") def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, qk_scale: Optional[float] = None, attn_drop: float = 0.0, proj_drop: float = 0.0, attn_mode: str = "xformers", ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) assert attn_mode in self.ATTENTION_MODES self.attn_mode = attn_mode def forward(self, x: torch.Tensor) -> torch.Tensor: B, L, C = x.shape qkv = self.qkv(x) if self.attn_mode == "torch": qkv = rearrange( qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads ).float() q, k, v = qkv[0], qkv[1], qkv[2] # B H L D x = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = rearrange(x, "B H L D -> B L (H D)") elif self.attn_mode == "xformers": qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) q, k, v = qkv[0], qkv[1], qkv[2] # B L H D x = xformers.ops.memory_efficient_attention(q, k, v) x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads) elif self.attn_mode == "math": qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k, v = qkv[0], qkv[1], qkv[2] # B H L D attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, L, C) else: raise NotImplemented x = self.proj(x) x = 
self.proj_drop(x) return x class SpatialSelfAttention(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.k = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.v = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q.shape q = rearrange(q, "b c h w -> b (h w) c") k = rearrange(k, "b c h w -> b c (h w)") w_ = torch.einsum("bij,bjk->bik", q, k) w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = rearrange(v, "b c h w -> b c (h w)") w_ = rearrange(w_, "b i j -> b j i") h_ = torch.einsum("bij,bjk->bik", v, w_) h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) h_ = self.proj_out(h_) return x + h_ class CrossAttention(nn.Module): def __init__( self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, backend=None, ): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) self.backend = backend def forward( self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0, ): h = self.heads if additional_tokens is not None: # get the number of masked tokens at the beginning of the output sequence n_tokens_to_mask = additional_tokens.shape[1] # add additional token x = torch.cat([additional_tokens, x], dim=1) q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) if n_times_crossframe_attn_in_self: # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 assert x.shape[0] % n_times_crossframe_attn_in_self == 0 n_cp = x.shape[0] // n_times_crossframe_attn_in_self k = repeat( k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp ) v = repeat( v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp ) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) ## old """ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale del q, k if exists(mask): mask = rearrange(mask, 'b ... 
-> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', sim, v) """ ## new with sdp_kernel(**BACKEND_MAP[self.backend]): # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask ) # scale is dim_head ** -0.5 per default del q, k, v out = rearrange(out, "b h n d -> b n (h d)", h=h) if additional_tokens is not None: # remove additional token out = out[:, n_tokens_to_mask:] return self.to_out(out) class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__( self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs ): super().__init__() logpy.debug( f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, " f"context_dim is {context_dim} and using {heads} heads with a " f"dimension of {dim_head}." ) inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.heads = heads self.dim_head = dim_head self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) self.attention_op: Optional[Any] = None def forward( self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0, align_w_first_frame = False, ): # print(x.shape, context is None) if additional_tokens is not None: # get the number of masked tokens at the beginning of the output sequence n_tokens_to_mask = additional_tokens.shape[1] # add additional token x = torch.cat([additional_tokens, x], dim=1) q = self.to_q(x) # print(context is None) context = default(context, x) k = self.to_k(context) v = self.to_v(context) # print(n_times_crossframe_attn_in_self) # exit(0) # n_times_crossframe_attn_in_self = 1 # print(x.shape, q.shape, k.shape, v.shape, n_times_crossframe_attn_in_self) if n_times_crossframe_attn_in_self: # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 assert x.shape[0] % n_times_crossframe_attn_in_self == 0 # n_cp = x.shape[0]//n_times_crossframe_attn_in_self k = repeat( k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_times_crossframe_attn_in_self, ) v = repeat( v[::n_times_crossframe_attn_in_self], "b ... 
-> (b n) ...", n=n_times_crossframe_attn_in_self, ) b, _, _ = q.shape q, k, v = map( lambda t: t.unsqueeze(3) .reshape(b, t.shape[1], self.heads, self.dim_head) .permute(0, 2, 1, 3) .reshape(b * self.heads, t.shape[1], self.dim_head) .contiguous(), (q, k, v), ) # actually compute the attention, what we cannot get enough of if version.parse(xformers.__version__) >= version.parse("0.0.21"): # NOTE: workaround for # https://github.com/facebookresearch/xformers/issues/845 max_bs = 32768 N = q.shape[0] n_batches = math.ceil(N / max_bs) out = list() for i_batch in range(n_batches): batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs) out.append( xformers.ops.memory_efficient_attention( q[batch], k[batch], v[batch], attn_bias=None, op=self.attention_op, ) ) out = torch.cat(out, 0) else: out = xformers.ops.memory_efficient_attention( q, k, v, attn_bias=None, op=self.attention_op ) # TODO: Use this directly in the attention operation, as a bias if exists(mask): raise NotImplementedError out = ( out.unsqueeze(0) .reshape(b, self.heads, out.shape[1], self.dim_head) .permute(0, 2, 1, 3) .reshape(b, out.shape[1], self.heads * self.dim_head) ) if additional_tokens is not None: # remove additional token out = out[:, n_tokens_to_mask:] return self.to_out(out) class BasicTransformerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention "softmax-xformers": MemoryEfficientCrossAttention, # ampere } def __init__( self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, attn_mode="softmax", sdp_backend=None, ): super().__init__() assert attn_mode in self.ATTENTION_MODES if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: logpy.warn( f"Attention mode '{attn_mode}' is not available. Falling " f"back to native attention. This is not a problem in " f"Pytorch >= 2.0. FYI, you are running with PyTorch " f"version {torch.__version__}." ) attn_mode = "softmax" elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: logpy.warn( "We do not support vanilla attention anymore, as it is too " "expensive. Sorry." ) if not XFORMERS_IS_AVAILABLE: assert ( False ), "Please install xformers via e.g. 
'pip install xformers==0.0.16'" else: logpy.info("Falling back to xformers efficient attention.") attn_mode = "softmax-xformers" attn_cls = self.ATTENTION_MODES[attn_mode] if version.parse(torch.__version__) >= version.parse("2.0.0"): assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) else: assert sdp_backend is None self.disable_self_attn = disable_self_attn self.attn1 = attn_cls( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None, backend=sdp_backend, ) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = attn_cls( query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, backend=sdp_backend, ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint if self.checkpoint: logpy.debug(f"{self.__class__.__name__} is using checkpointing") def forward( self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 ): kwargs = {"x": x} if context is not None: kwargs.update({"context": context}) if additional_tokens is not None: kwargs.update({"additional_tokens": additional_tokens}) if n_times_crossframe_attn_in_self: kwargs.update( {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} ) # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) if self.checkpoint: # inputs = {"x": x, "context": context} return checkpoint(self._forward, x, context) # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) else: return self._forward(**kwargs) def _forward( self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 ): x = ( self.attn1( self.norm1(x), context=context if self.disable_self_attn else None, additional_tokens=additional_tokens, n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, ) + x ) x = ( self.attn2( self.norm2(x), context=context, additional_tokens=additional_tokens ) + x ) x = self.ff(self.norm3(x)) + x return x class BasicTransformerSingleLayerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) } def __init__( self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True, attn_mode="softmax", ): super().__init__() assert attn_mode in self.ATTENTION_MODES attn_cls = self.ATTENTION_MODES[attn_mode] self.attn1 = attn_cls( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim, ) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): # inputs = {"x": x, "context": context} # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) return checkpoint(self._forward, x, context) def _forward(self, x, context=None): x = self.attn1(self.norm1(x), context=context) + x x = self.ff(self.norm2(x)) + x return x class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. 
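(Shape sketch, assuming a single context tensor: x of shape (b, c, h, w) is normalized, projected to inner_dim = n_heads * d_head, flattened to (b, h*w, inner_dim), and passed through the BasicTransformerBlocks with an optional cross-attention context of shape (b, n_ctx, context_dim).)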
Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ def __init__( self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None, disable_self_attn=False, use_linear=False, attn_type="softmax", use_checkpoint=True, # sdp_backend=SDPBackend.FLASH_ATTENTION sdp_backend=None, ): super().__init__() logpy.debug( f"constructing {self.__class__.__name__} of depth {depth} w/ " f"{in_channels} channels and {n_heads} heads." ) if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] if exists(context_dim) and isinstance(context_dim, list): if depth != len(context_dim): logpy.warn( f"{self.__class__.__name__}: Found context dims " f"{context_dim} of depth {len(context_dim)}, which does not " f"match the specified 'depth' of {depth}. Setting context_dim " f"to {depth * [context_dim[0]]} now." ) # depth does not match context dims. assert all( map(lambda x: x == context_dim[0], context_dim) ), "need homogenous context_dim to match depth automatically" context_dim = depth * [context_dim[0]] elif context_dim is None: context_dim = [None] * depth self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if not use_linear: self.proj_in = nn.Conv2d( in_channels, inner_dim, kernel_size=1, stride=1, padding=0 ) else: self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, attn_mode=attn_type, checkpoint=use_checkpoint, sdp_backend=sdp_backend, ) for d in range(depth) ] ) if not use_linear: self.proj_out = zero_module( nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) ) else: # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) self.use_linear = use_linear def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] b, c, h, w = x.shape x_in = x x = self.norm(x) if not self.use_linear: x = self.proj_in(x) x = rearrange(x, "b c h w -> b (h w) c").contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): if i > 0 and len(context) == 1: i = 0 # use same context for each block x = block(x, context=context[i]) if self.use_linear: x = self.proj_out(x) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in class SimpleTransformer(nn.Module): def __init__( self, dim: int, depth: int, heads: int, dim_head: int, context_dim: Optional[int] = None, dropout: float = 0.0, checkpoint: bool = True, ): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( BasicTransformerBlock( dim, heads, dim_head, dropout=dropout, context_dim=context_dim, attn_mode="softmax-xformers", checkpoint=checkpoint, ) ) def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, ) -> torch.Tensor: for layer in self.layers: x = layer(x, context) return x ================================================ FILE: sgm/modules/autoencoding/__init__.py ================================================ ================================================ FILE: sgm/modules/autoencoding/losses/__init__.py ================================================ __all__ = [ "GeneralLPIPSWithDiscriminator", "LatentLPIPS", ] 
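# Illustrative wiring sketch (the disc_start value below is a placeholder, not a
# recommended setting): either loss can be handed to AutoencodingEngine via its
# `loss_config`, e.g.
#   {"target": "sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator",
#    "params": {"disc_start": 50001}}
# which sgm.util.instantiate_from_config resolves to the class imported below.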
from .discriminator_loss import GeneralLPIPSWithDiscriminator from .lpips import LatentLPIPS ================================================ FILE: sgm/modules/autoencoding/losses/discriminator_loss.py ================================================ from typing import Dict, Iterator, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torchvision from einops import rearrange from matplotlib import colormaps from matplotlib import pyplot as plt from ....util import default, instantiate_from_config from ..lpips.loss.lpips import LPIPS from ..lpips.model.model import weights_init from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss class GeneralLPIPSWithDiscriminator(nn.Module): def __init__( self, disc_start: int, logvar_init: float = 0.0, disc_num_layers: int = 3, disc_in_channels: int = 3, disc_factor: float = 1.0, disc_weight: float = 1.0, perceptual_weight: float = 1.0, disc_loss: str = "hinge", scale_input_to_tgt_size: bool = False, dims: int = 2, learn_logvar: bool = False, regularization_weights: Union[None, Dict[str, float]] = None, additional_log_keys: Optional[List[str]] = None, discriminator_config: Optional[Dict] = None, ): super().__init__() self.dims = dims if self.dims > 2: print( f"running with dims={dims}. This means that for perceptual loss " f"calculation, the LPIPS loss will be applied to each frame " f"independently." ) self.scale_input_to_tgt_size = scale_input_to_tgt_size assert disc_loss in ["hinge", "vanilla"] self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight # output log variance self.logvar = nn.Parameter( torch.full((), logvar_init), requires_grad=learn_logvar ) self.learn_logvar = learn_logvar discriminator_config = default( discriminator_config, { "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", "params": { "input_nc": disc_in_channels, "n_layers": disc_num_layers, "use_actnorm": False, }, }, ) self.discriminator = instantiate_from_config(discriminator_config).apply( weights_init ) self.discriminator_iter_start = disc_start self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.regularization_weights = default(regularization_weights, {}) self.forward_keys = [ "optimizer_idx", "global_step", "last_layer", "split", "regularization_log", ] self.additional_log_keys = set(default(additional_log_keys, [])) self.additional_log_keys.update(set(self.regularization_weights.keys())) def get_trainable_parameters(self) -> Iterator[nn.Parameter]: return self.discriminator.parameters() def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: if self.learn_logvar: yield self.logvar yield from () @torch.no_grad() def log_images( self, inputs: torch.Tensor, reconstructions: torch.Tensor ) -> Dict[str, torch.Tensor]: # calc logits of real/fake logits_real = self.discriminator(inputs.contiguous().detach()) if len(logits_real.shape) < 4: # Non patch-discriminator return dict() logits_fake = self.discriminator(reconstructions.contiguous().detach()) # -> (b, 1, h, w) # parameters for colormapping high = max(logits_fake.abs().max(), logits_real.abs().max()).item() cmap = colormaps["PiYG"] # diverging colormap def to_colormap(logits: torch.Tensor) -> torch.Tensor: """(b, 1, ...) 
-> (b, 3, ...)""" logits = (logits + high) / (2 * high) logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel # -> (b, 1, ..., 3) logits = torch.from_numpy(logits_np).to(logits.device) return rearrange(logits, "b 1 ... c -> b c ...") logits_real = torch.nn.functional.interpolate( logits_real, size=inputs.shape[-2:], mode="nearest", antialias=False, ) logits_fake = torch.nn.functional.interpolate( logits_fake, size=reconstructions.shape[-2:], mode="nearest", antialias=False, ) # alpha value of logits for overlay alpha_real = torch.abs(logits_real) / high alpha_fake = torch.abs(logits_fake) / high # -> (b, 1, h, w) in range [0, 0.5] # alpha value of lines don't really matter, since the values are the same # for both images and logits anyway grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) # -> (1, h, w) # blend logits and images together # prepare logits for plotting logits_real = to_colormap(logits_real) logits_fake = to_colormap(logits_fake) # resize logits # -> (b, 3, h, w) # make some grids # add all logits to one plot logits_real = torchvision.utils.make_grid(logits_real, nrow=4) logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) # I just love how torchvision calls the number of columns `nrow` grid_logits = torch.cat((logits_real, logits_fake), dim=1) # -> (3, h, w) grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) grid_images_fake = torchvision.utils.make_grid( 0.5 * reconstructions + 0.5, nrow=4 ) grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) # -> (3, h, w) in range [0, 1] grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images # Create labeled colorbar dpi = 100 height = 128 / dpi width = grid_logits.shape[2] / dpi fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) img = ax.imshow(np.array([[-high, high]]), cmap=cmap) plt.colorbar( img, cax=ax, orientation="horizontal", fraction=0.9, aspect=width / height, pad=0.0, ) img.set_visible(False) fig.tight_layout() fig.canvas.draw() # manually convert figure to numpy cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) # Add colorbar to plot annotated_grid = torch.cat((grid_logits, cbar), dim=1) blended_grid = torch.cat((grid_blend, cbar), dim=1) return { "vis_logits": 2 * annotated_grid[None, ...] - 1, "vis_logits_blended": 2 * blended_grid[None, ...] 
- 1, } def calculate_adaptive_weight( self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor ) -> torch.Tensor: nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def forward( self, inputs: torch.Tensor, reconstructions: torch.Tensor, *, # added because I changed the order here regularization_log: Dict[str, torch.Tensor], optimizer_idx: int, global_step: int, last_layer: torch.Tensor, split: str = "train", weights: Union[None, float, torch.Tensor] = None, ) -> Tuple[torch.Tensor, dict]: if self.scale_input_to_tgt_size: inputs = torch.nn.functional.interpolate( inputs, reconstructions.shape[2:], mode="bicubic", antialias=True ) if self.dims > 2: inputs, reconstructions = map( lambda x: rearrange(x, "b c t h w -> (b t) c h w"), (inputs, reconstructions), ) rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss = self.perceptual_loss( inputs.contiguous(), reconstructions.contiguous() ) rec_loss = rec_loss + self.perceptual_weight * p_loss nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) # now the GAN part if optimizer_idx == 0: # generator update if global_step >= self.discriminator_iter_start or not self.training: logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) if self.training: d_weight = self.calculate_adaptive_weight( nll_loss, g_loss, last_layer=last_layer ) else: d_weight = torch.tensor(1.0) else: d_weight = torch.tensor(0.0) g_loss = torch.tensor(0.0, requires_grad=True) loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss log = dict() for k in regularization_log: if k in self.regularization_weights: loss = loss + self.regularization_weights[k] * regularization_log[k] if k in self.additional_log_keys: log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() log.update( { f"{split}/loss/total": loss.clone().detach().mean(), f"{split}/loss/nll": nll_loss.detach().mean(), f"{split}/loss/rec": rec_loss.detach().mean(), f"{split}/loss/g": g_loss.detach().mean(), f"{split}/scalars/logvar": self.logvar.detach(), f"{split}/scalars/d_weight": d_weight.detach(), } ) return loss, log elif optimizer_idx == 1: # second pass for discriminator update logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator(reconstructions.contiguous().detach()) if global_step >= self.discriminator_iter_start or not self.training: d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) else: d_loss = torch.tensor(0.0, requires_grad=True) log = { f"{split}/loss/disc": d_loss.clone().detach().mean(), f"{split}/logits/real": logits_real.detach().mean(), f"{split}/logits/fake": logits_fake.detach().mean(), } return d_loss, log else: raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") def get_nll_loss( self, rec_loss: torch.Tensor, weights: Optional[Union[float, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: weighted_nll_loss = weights * nll_loss weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] return nll_loss, 
weighted_nll_loss ================================================ FILE: sgm/modules/autoencoding/losses/lpips.py ================================================ import torch import torch.nn as nn from ....util import default, instantiate_from_config from ..lpips.loss.lpips import LPIPS class LatentLPIPS(nn.Module): def __init__( self, decoder_config, perceptual_weight=1.0, latent_weight=1.0, scale_input_to_tgt_size=False, scale_tgt_to_input_size=False, perceptual_weight_on_inputs=0.0, ): super().__init__() self.scale_input_to_tgt_size = scale_input_to_tgt_size self.scale_tgt_to_input_size = scale_tgt_to_input_size self.init_decoder(decoder_config) self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight self.latent_weight = latent_weight self.perceptual_weight_on_inputs = perceptual_weight_on_inputs def init_decoder(self, config): self.decoder = instantiate_from_config(config) if hasattr(self.decoder, "encoder"): del self.decoder.encoder def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): log = dict() loss = (latent_inputs - latent_predictions) ** 2 log[f"{split}/latent_l2_loss"] = loss.mean().detach() image_reconstructions = None if self.perceptual_weight > 0.0: image_reconstructions = self.decoder.decode(latent_predictions) image_targets = self.decoder.decode(latent_inputs) perceptual_loss = self.perceptual_loss( image_targets.contiguous(), image_reconstructions.contiguous() ) loss = ( self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() ) log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() if self.perceptual_weight_on_inputs > 0.0: image_reconstructions = default( image_reconstructions, self.decoder.decode(latent_predictions) ) if self.scale_input_to_tgt_size: image_inputs = torch.nn.functional.interpolate( image_inputs, image_reconstructions.shape[2:], mode="bicubic", antialias=True, ) elif self.scale_tgt_to_input_size: image_reconstructions = torch.nn.functional.interpolate( image_reconstructions, image_inputs.shape[2:], mode="bicubic", antialias=True, ) perceptual_loss2 = self.perceptual_loss( image_inputs.contiguous(), image_reconstructions.contiguous() ) loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() return loss, log ================================================ FILE: sgm/modules/autoencoding/lpips/__init__.py ================================================ ================================================ FILE: sgm/modules/autoencoding/lpips/loss/.gitignore ================================================ vgg.pth ================================================ FILE: sgm/modules/autoencoding/lpips/loss/LICENSE ================================================ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================ FILE: sgm/modules/autoencoding/lpips/loss/__init__.py ================================================
================================================ FILE: sgm/modules/autoencoding/lpips/loss/lpips.py ================================================
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

from collections import namedtuple

import torch
import torch.nn as nn
from torchvision import models

from ..util import get_ckpt_path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
                outs1[kk]
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def
__init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = ( [ nn.Dropout(), ] if (use_dropout) else [] ) layers += [ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = models.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple( "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] ) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out def normalize_tensor(x, eps=1e-10): norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) return x / (norm_factor + eps) def spatial_average(x, keepdim=True): return x.mean([2, 3], keepdim=keepdim) ================================================ FILE: sgm/modules/autoencoding/lpips/model/LICENSE ================================================ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --------------------------- LICENSE FOR pix2pix -------------------------------- BSD License For pix2pix software Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. ----------------------------- LICENSE FOR DCGAN -------------------------------- BSD License For dcgan.torch software Copyright (c) 2015, Facebook, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
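For orientation, the LPIPS metric defined above in sgm/modules/autoencoding/lpips/loss/lpips.py can be exercised roughly as follows. This is an illustrative sketch rather than a file from the repository; it assumes the pretrained vgg.pth linear-layer weights are fetched on first use via get_ckpt_path, and that inputs are RGB batches scaled to [-1, 1] as the ScalingLayer constants expect.

import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

# Instantiate once; this loads the pretrained 1x1-conv heads over frozen VGG16 features.
lpips = LPIPS().eval()

# Two batches of RGB images scaled to [-1, 1].
x = torch.rand(4, 3, 256, 256) * 2.0 - 1.0
y = torch.rand(4, 3, 256, 256) * 2.0 - 1.0

with torch.no_grad():
    # One perceptual distance per image pair, shape (4, 1, 1, 1):
    # unit-normalized VGG features per layer, squared differences,
    # learned 1x1 convs, spatial averaging, then a sum over layers.
    dist = lpips(x, y)
print(dist.flatten())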
================================================ FILE: sgm/modules/autoencoding/lpips/model/__init__.py ================================================ ================================================ FILE: sgm/modules/autoencoding/lpips/model/model.py ================================================ import functools import torch.nn as nn from ..util import ActNorm def weights_init(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("BatchNorm") != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator as in Pix2Pix --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py """ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminator, self).__init__() if not use_actnorm: norm_layer = nn.BatchNorm2d else: norm_layer = ActNorm if ( type(norm_layer) == functools.partial ): # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d kw = 4 padw = 1 sequence = [ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True), ] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2**n, 8) sequence += [ nn.Conv2d( ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias, ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True), ] nf_mult_prev = nf_mult nf_mult = min(2**n_layers, 8) sequence += [ nn.Conv2d( ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias, ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True), ] sequence += [ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) ] # output 1 channel prediction map self.main = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" return self.main(input) ================================================ FILE: sgm/modules/autoencoding/lpips/util.py ================================================ import hashlib import os import requests import torch import torch.nn as nn from tqdm import tqdm URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} CKPT_MAP = {"vgg_lpips": "vgg.pth"} MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} def download(url, local_path, chunk_size=1024): os.makedirs(os.path.split(local_path)[0], exist_ok=True) with requests.get(url, stream=True) as r: total_size = int(r.headers.get("content-length", 0)) with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: with open(local_path, "wb") as f: for data in r.iter_content(chunk_size=chunk_size): if data: f.write(data) pbar.update(chunk_size) def md5_hash(path): with open(path, "rb") as f: content = f.read() return hashlib.md5(content).hexdigest() def get_ckpt_path(name, root, check=False): assert name in URL_MAP path = os.path.join(root, CKPT_MAP[name]) if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 
download(URL_MAP[name], path) md5 = md5_hash(path) assert md5 == MD5_MAP[name], md5 return path class ActNorm(nn.Module): def __init__( self, num_features, logdet=False, affine=True, allow_reverse_init=False ): assert affine super().__init__() self.logdet = logdet self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) self.allow_reverse_init = allow_reverse_init self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) def initialize(self, input): with torch.no_grad(): flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) mean = ( flatten.mean(1) .unsqueeze(1) .unsqueeze(2) .unsqueeze(3) .permute(1, 0, 2, 3) ) std = ( flatten.std(1) .unsqueeze(1) .unsqueeze(2) .unsqueeze(3) .permute(1, 0, 2, 3) ) self.loc.data.copy_(-mean) self.scale.data.copy_(1 / (std + 1e-6)) def forward(self, input, reverse=False): if reverse: return self.reverse(input) if len(input.shape) == 2: input = input[:, :, None, None] squeeze = True else: squeeze = False _, _, height, width = input.shape if self.training and self.initialized.item() == 0: self.initialize(input) self.initialized.fill_(1) h = self.scale * (input + self.loc) if squeeze: h = h.squeeze(-1).squeeze(-1) if self.logdet: log_abs = torch.log(torch.abs(self.scale)) logdet = height * width * torch.sum(log_abs) logdet = logdet * torch.ones(input.shape[0]).to(input) return h, logdet return h def reverse(self, output): if self.training and self.initialized.item() == 0: if not self.allow_reverse_init: raise RuntimeError( "Initializing ActNorm in reverse direction is " "disabled by default. Use allow_reverse_init=True to enable." ) else: self.initialize(output) self.initialized.fill_(1) if len(output.shape) == 2: output = output[:, :, None, None] squeeze = True else: squeeze = False h = output / self.scale - self.loc if squeeze: h = h.squeeze(-1).squeeze(-1) return h ================================================ FILE: sgm/modules/autoencoding/lpips/vqperceptual.py ================================================ import torch import torch.nn.functional as F def hinge_d_loss(logits_real, logits_fake): loss_real = torch.mean(F.relu(1.0 - logits_real)) loss_fake = torch.mean(F.relu(1.0 + logits_fake)) d_loss = 0.5 * (loss_real + loss_fake) return d_loss def vanilla_d_loss(logits_real, logits_fake): d_loss = 0.5 * ( torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) ) return d_loss ================================================ FILE: sgm/modules/autoencoding/regularizers/__init__.py ================================================ from abc import abstractmethod from typing import Any, Tuple import torch import torch.nn as nn import torch.nn.functional as F from ....modules.distributions.distributions import \ DiagonalGaussianDistribution from .base import AbstractRegularizer class DiagonalGaussianRegularizer(AbstractRegularizer): def __init__(self, sample: bool = True): super().__init__() self.sample = sample def get_trainable_parameters(self) -> Any: yield from () def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: log = dict() posterior = DiagonalGaussianDistribution(z) if self.sample: z = posterior.sample() else: z = posterior.mode() kl_loss = posterior.kl() kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] log["kl_loss"] = kl_loss return z, log ================================================ FILE: sgm/modules/autoencoding/regularizers/base.py 
================================================ from abc import abstractmethod from typing import Any, Tuple import torch import torch.nn.functional as F from torch import nn class AbstractRegularizer(nn.Module): def __init__(self): super().__init__() def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: raise NotImplementedError() @abstractmethod def get_trainable_parameters(self) -> Any: raise NotImplementedError() class IdentityRegularizer(AbstractRegularizer): def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: return z, dict() def get_trainable_parameters(self) -> Any: yield from () def measure_perplexity( predicted_indices: torch.Tensor, num_centroids: int ) -> Tuple[torch.Tensor, torch.Tensor]: # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally encodings = ( F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) ) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use ================================================ FILE: sgm/modules/autoencoding/regularizers/quantize.py ================================================ import logging from abc import abstractmethod from typing import Dict, Iterator, Literal, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch import einsum from .base import AbstractRegularizer, measure_perplexity logpy = logging.getLogger(__name__) class AbstractQuantizer(AbstractRegularizer): def __init__(self): super().__init__() # Define these in your init # shape (N,) self.used: Optional[torch.Tensor] self.re_embed: int self.unknown_index: Union[Literal["random"], int] def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: assert self.used is not None, "You need to define used indices for remap" ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) match = (inds[:, :, None] == used[None, None, ...]).long() new = match.argmax(-1) unknown = match.sum(2) < 1 if self.unknown_index == "random": new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( device=new.device ) else: new[unknown] = self.unknown_index return new.reshape(ishape) def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: assert self.used is not None, "You need to define used indices for remap" ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) if self.re_embed > self.used.shape[0]: # extra token inds[inds >= self.used.shape[0]] = 0 # simply set to zero back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) @abstractmethod def get_codebook_entry( self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None ) -> torch.Tensor: raise NotImplementedError() def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: yield from self.parameters() class GumbelQuantizer(AbstractQuantizer): """ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) Gumbel Softmax trick quantizer Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 https://arxiv.org/abs/1611.01144 """ def __init__( self, num_hiddens: int, embedding_dim: int, n_embed: int, straight_through: bool = True, kl_weight: float = 5e-4, temp_init: float = 1.0, remap: Optional[str] = None, unknown_index: str = "random", loss_key: str = "loss/vq", ) -> None: super().__init__() self.loss_key = loss_key self.embedding_dim = embedding_dim self.n_embed = n_embed self.straight_through = straight_through self.temperature = temp_init self.kl_weight = kl_weight self.proj = nn.Conv2d(num_hiddens, n_embed, 1) self.embed = nn.Embedding(n_embed, embedding_dim) self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] else: self.used = None self.re_embed = n_embed if unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 else: assert unknown_index == "random" or isinstance( unknown_index, int ), "unknown index needs to be 'random', 'extra' or any integer" self.unknown_index = unknown_index # "random" or "extra" or integer if self.remap is not None: logpy.info( f"Remapping {self.n_embed} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices." ) def forward( self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False ) -> Tuple[torch.Tensor, Dict]: # force hard = True when we are in eval mode, as we must quantize. # actually, always true seems to work hard = self.straight_through if self.training else True temp = self.temperature if temp is None else temp out_dict = {} logits = self.proj(z) if self.remap is not None: # continue only with used logits full_zeros = torch.zeros_like(logits) logits = logits[:, self.used, ...] soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) if self.remap is not None: # go back to all entries but unused set to zero full_zeros[:, self.used, ...] = soft_one_hot soft_one_hot = full_zeros z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) # + kl divergence to the prior loss qy = F.softmax(logits, dim=1) diff = ( self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() ) out_dict[self.loss_key] = diff ind = soft_one_hot.argmax(dim=1) out_dict["indices"] = ind if self.remap is not None: ind = self.remap_to_used(ind) if return_logits: out_dict["logits"] = logits return z_q, out_dict def get_codebook_entry(self, indices, shape): # TODO: shape not yet optional b, h, w, c = shape assert b * h * w == indices.shape[0] indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) if self.remap is not None: indices = self.unmap_to_all(indices) one_hot = ( F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() ) z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) return z_q class VectorQuantizer(AbstractQuantizer): """ ____________________________________________ Discretization bottleneck part of the VQ-VAE. 
Inputs: - n_e : number of embeddings - e_dim : dimension of embedding - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 _____________________________________________ """ def __init__( self, n_e: int, e_dim: int, beta: float = 0.25, remap: Optional[str] = None, unknown_index: str = "random", sane_index_shape: bool = False, log_perplexity: bool = False, embedding_weight_norm: bool = False, loss_key: str = "loss/vq", ): super().__init__() self.n_e = n_e self.e_dim = e_dim self.beta = beta self.loss_key = loss_key if not embedding_weight_norm: self.embedding = nn.Embedding(self.n_e, self.e_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) else: self.embedding = torch.nn.utils.weight_norm( nn.Embedding(self.n_e, self.e_dim), dim=1 ) self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] else: self.used = None self.re_embed = n_e if unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 else: assert unknown_index == "random" or isinstance( unknown_index, int ), "unknown index needs to be 'random', 'extra' or any integer" self.unknown_index = unknown_index # "random" or "extra" or integer if self.remap is not None: logpy.info( f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices." ) self.sane_index_shape = sane_index_shape self.log_perplexity = log_perplexity def forward( self, z: torch.Tensor, ) -> Tuple[torch.Tensor, Dict]: do_reshape = z.ndim == 4 if do_reshape: # # reshape z -> (batch, height, width, channel) and flatten z = rearrange(z, "b c h w -> b h w c").contiguous() else: assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" z = z.contiguous() z_flattened = z.view(-1, self.e_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = ( torch.sum(z_flattened**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.einsum( "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") ) ) min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) loss_dict = {} if self.log_perplexity: perplexity, cluster_usage = measure_perplexity( min_encoding_indices.detach(), self.n_e ) loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) # compute loss for embedding loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( (z_q - z.detach()) ** 2 ) loss_dict[self.loss_key] = loss # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape if do_reshape: z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() if self.remap is not None: min_encoding_indices = min_encoding_indices.reshape( z.shape[0], -1 ) # add batch axis min_encoding_indices = self.remap_to_used(min_encoding_indices) min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten if self.sane_index_shape: if do_reshape: min_encoding_indices = min_encoding_indices.reshape( z_q.shape[0], z_q.shape[2], z_q.shape[3] ) else: min_encoding_indices = rearrange( min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] ) loss_dict["min_encoding_indices"] = min_encoding_indices return z_q, loss_dict def get_codebook_entry( self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None ) -> torch.Tensor: # shape specifying (batch, height, width, channel) if self.remap is not None: assert shape is not None, "Need to give shape 
for remap" indices = indices.reshape(shape[0], -1) # add batch axis indices = self.unmap_to_all(indices) indices = indices.reshape(-1) # flatten again # get quantized latent vectors z_q = self.embedding(indices) if shape is not None: z_q = z_q.view(shape) # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() return z_q class EmbeddingEMA(nn.Module): def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): super().__init__() self.decay = decay self.eps = eps weight = torch.randn(num_tokens, codebook_dim) self.weight = nn.Parameter(weight, requires_grad=False) self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) self.update = True def forward(self, embed_id): return F.embedding(embed_id, self.weight) def cluster_size_ema_update(self, new_cluster_size): self.cluster_size.data.mul_(self.decay).add_( new_cluster_size, alpha=1 - self.decay ) def embed_avg_ema_update(self, new_embed_avg): self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) def weight_update(self, num_tokens): n = self.cluster_size.sum() smoothed_cluster_size = ( (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n ) # normalize embedding average with smoothed cluster size embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) self.weight.data.copy_(embed_normalized) class EMAVectorQuantizer(AbstractQuantizer): def __init__( self, n_embed: int, embedding_dim: int, beta: float, decay: float = 0.99, eps: float = 1e-5, remap: Optional[str] = None, unknown_index: str = "random", loss_key: str = "loss/vq", ): super().__init__() self.codebook_dim = embedding_dim self.num_tokens = n_embed self.beta = beta self.loss_key = loss_key self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] else: self.used = None self.re_embed = n_embed if unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 else: assert unknown_index == "random" or isinstance( unknown_index, int ), "unknown index needs to be 'random', 'extra' or any integer" self.unknown_index = unknown_index # "random" or "extra" or integer if self.remap is not None: logpy.info( f"Remapping {self.n_embed} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices." 
) def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: # reshape z -> (batch, height, width, channel) and flatten # z, 'b c h w -> b h w c' z = rearrange(z, "b c h w -> b h w c") z_flattened = z.reshape(-1, self.codebook_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = ( z_flattened.pow(2).sum(dim=1, keepdim=True) + self.embedding.weight.pow(2).sum(dim=1) - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) ) # 'n d -> d n' encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(encoding_indices).view(z.shape) encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) avg_probs = torch.mean(encodings, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) if self.training and self.embedding.update: # EMA cluster size encodings_sum = encodings.sum(0) self.embedding.cluster_size_ema_update(encodings_sum) # EMA embedding average embed_sum = encodings.transpose(0, 1) @ z_flattened self.embedding.embed_avg_ema_update(embed_sum) # normalize embed_avg and update weight self.embedding.weight_update(self.num_tokens) # compute loss for embedding loss = self.beta * F.mse_loss(z_q.detach(), z) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape # z_q, 'b h w c -> b c h w' z_q = rearrange(z_q, "b h w c -> b c h w") out_dict = { self.loss_key: loss, "encodings": encodings, "encoding_indices": encoding_indices, "perplexity": perplexity, } return z_q, out_dict class VectorQuantizerWithInputProjection(VectorQuantizer): def __init__( self, input_dim: int, n_codes: int, codebook_dim: int, beta: float = 1.0, output_dim: Optional[int] = None, **kwargs, ): super().__init__(n_codes, codebook_dim, beta, **kwargs) self.proj_in = nn.Linear(input_dim, codebook_dim) self.output_dim = output_dim if output_dim is not None: self.proj_out = nn.Linear(codebook_dim, output_dim) else: self.proj_out = nn.Identity() def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: rearr = False in_shape = z.shape if z.ndim > 3: rearr = self.output_dim is not None z = rearrange(z, "b c ... -> b (...) c") z = self.proj_in(z) z_q, loss_dict = super().forward(z) z_q = self.proj_out(z_q) if rearr: if len(in_shape) == 4: z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) elif len(in_shape) == 5: z_q = rearrange( z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] ) else: raise NotImplementedError( f"rearranging not available for {len(in_shape)}-dimensional input." 
) return z_q, loss_dict ================================================ FILE: sgm/modules/autoencoding/temporal_ae.py ================================================ from typing import Callable, Iterable, Union import torch from einops import rearrange, repeat from sgm.modules.diffusionmodules.model import ( XFORMERS_IS_AVAILABLE, AttnBlock, Decoder, MemoryEfficientAttnBlock, ResnetBlock, ) from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding from sgm.modules.video_attention import VideoTransformerBlock from sgm.util import partialclass class VideoResBlock(ResnetBlock): def __init__( self, out_channels, *args, dropout=0.0, video_kernel_size=3, alpha=0.0, merge_strategy="learned", **kwargs, ): super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) if video_kernel_size is None: video_kernel_size = [3, 1, 1] self.time_stack = ResBlock( channels=out_channels, emb_channels=0, dropout=dropout, dims=3, use_scale_shift_norm=False, use_conv=False, up=False, down=False, kernel_size=video_kernel_size, use_checkpoint=False, skip_t_emb=True, ) self.merge_strategy = merge_strategy if self.merge_strategy == "fixed": self.register_buffer("mix_factor", torch.Tensor([alpha])) elif self.merge_strategy == "learned": self.register_parameter( "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) ) else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") def get_alpha(self, bs): if self.merge_strategy == "fixed": return self.mix_factor elif self.merge_strategy == "learned": return torch.sigmoid(self.mix_factor) else: raise NotImplementedError() def forward(self, x, temb, skip_video=False, timesteps=None): if timesteps is None: timesteps = self.timesteps b, c, h, w = x.shape x = super().forward(x, temb) if not skip_video: x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) x = self.time_stack(x, temb) alpha = self.get_alpha(bs=b // timesteps) x = alpha * x + (1.0 - alpha) * x_mix x = rearrange(x, "b c t h w -> (b t) c h w") return x class AE3DConv(torch.nn.Conv2d): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): super().__init__(in_channels, out_channels, *args, **kwargs) if isinstance(video_kernel_size, Iterable): padding = [int(k // 2) for k in video_kernel_size] else: padding = int(video_kernel_size // 2) self.time_mix_conv = torch.nn.Conv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=video_kernel_size, padding=padding, ) def forward(self, input, timesteps, skip_video=False): x = super().forward(input) if skip_video: return x x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) x = self.time_mix_conv(x) return rearrange(x, "b c t h w -> (b t) c h w") class VideoBlock(AttnBlock): def __init__( self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" ): super().__init__(in_channels) # no context, single headed, as in base class self.time_mix_block = VideoTransformerBlock( dim=in_channels, n_heads=1, d_head=in_channels, checkpoint=False, ff_in=True, attn_mode="softmax", ) time_embed_dim = self.in_channels * 4 self.video_time_embed = torch.nn.Sequential( torch.nn.Linear(self.in_channels, time_embed_dim), torch.nn.SiLU(), torch.nn.Linear(time_embed_dim, self.in_channels), ) self.merge_strategy = merge_strategy if self.merge_strategy == "fixed": self.register_buffer("mix_factor", torch.Tensor([alpha])) elif self.merge_strategy == "learned": self.register_parameter( "mix_factor", 
torch.nn.Parameter(torch.Tensor([alpha])) ) else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") def forward(self, x, timesteps, skip_video=False): if skip_video: return super().forward(x) x_in = x x = self.attention(x) h, w = x.shape[2:] x = rearrange(x, "b c h w -> b (h w) c") x_mix = x num_frames = torch.arange(timesteps, device=x.device) num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) num_frames = rearrange(num_frames, "b t -> (b t)") t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) emb = self.video_time_embed(t_emb) # b, n_channels emb = emb[:, None, :] x_mix = x_mix + emb alpha = self.get_alpha() x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x = alpha * x + (1.0 - alpha) * x_mix # alpha merge x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) x = self.proj_out(x) return x_in + x def get_alpha( self, ): if self.merge_strategy == "fixed": return self.mix_factor elif self.merge_strategy == "learned": return torch.sigmoid(self.mix_factor) else: raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): def __init__( self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" ): super().__init__(in_channels) # no context, single headed, as in base class self.time_mix_block = VideoTransformerBlock( dim=in_channels, n_heads=1, d_head=in_channels, checkpoint=False, ff_in=True, attn_mode="softmax-xformers", ) time_embed_dim = self.in_channels * 4 self.video_time_embed = torch.nn.Sequential( torch.nn.Linear(self.in_channels, time_embed_dim), torch.nn.SiLU(), torch.nn.Linear(time_embed_dim, self.in_channels), ) self.merge_strategy = merge_strategy if self.merge_strategy == "fixed": self.register_buffer("mix_factor", torch.Tensor([alpha])) elif self.merge_strategy == "learned": self.register_parameter( "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) ) else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") def forward(self, x, timesteps, skip_time_block=False): if skip_time_block: return super().forward(x) x_in = x x = self.attention(x) h, w = x.shape[2:] x = rearrange(x, "b c h w -> b (h w) c") x_mix = x num_frames = torch.arange(timesteps, device=x.device) num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) num_frames = rearrange(num_frames, "b t -> (b t)") t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) emb = self.video_time_embed(t_emb) # b, n_channels emb = emb[:, None, :] x_mix = x_mix + emb alpha = self.get_alpha() x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x = alpha * x + (1.0 - alpha) * x_mix # alpha merge x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) x = self.proj_out(x) return x_in + x def get_alpha( self, ): if self.merge_strategy == "fixed": return self.mix_factor elif self.merge_strategy == "learned": return torch.sigmoid(self.mix_factor) else: raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") def make_time_attn( in_channels, attn_type="vanilla", attn_kwargs=None, alpha: float = 0, merge_strategy: str = "learned", ): assert attn_type in [ "vanilla", "vanilla-xformers", ], f"attn_type {attn_type} not supported for spatio-temporal attention" print( f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" ) if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": print( f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. 
" f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" ) attn_type = "vanilla" if attn_type == "vanilla": assert attn_kwargs is None return partialclass( VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy ) elif attn_type == "vanilla-xformers": print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") return partialclass( MemoryEfficientVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy, ) else: return NotImplementedError() class Conv2DWrapper(torch.nn.Conv2d): def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: return super().forward(input) class VideoDecoder(Decoder): available_time_modes = ["all", "conv-only", "attn-only"] def __init__( self, *args, video_kernel_size: Union[int, list] = 3, alpha: float = 0.0, merge_strategy: str = "learned", time_mode: str = "conv-only", **kwargs, ): self.video_kernel_size = video_kernel_size self.alpha = alpha self.merge_strategy = merge_strategy self.time_mode = time_mode assert ( self.time_mode in self.available_time_modes ), f"time_mode parameter has to be in {self.available_time_modes}" super().__init__(*args, **kwargs) def get_last_layer(self, skip_time_mix=False, **kwargs): if self.time_mode == "attn-only": raise NotImplementedError("TODO") else: return ( self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight ) def _make_attn(self) -> Callable: if self.time_mode not in ["conv-only", "only-last-conv"]: return partialclass( make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy, ) else: return super()._make_attn() def _make_conv(self) -> Callable: if self.time_mode != "attn-only": return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) else: return Conv2DWrapper def _make_resblock(self) -> Callable: if self.time_mode not in ["attn-only", "only-last-conv"]: return partialclass( VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy, ) else: return super()._make_resblock() ================================================ FILE: sgm/modules/diffusionmodules/__init__.py ================================================ ================================================ FILE: sgm/modules/diffusionmodules/denoiser.py ================================================ from typing import Dict, Union import torch import torch.nn as nn from ...util import append_dims, instantiate_from_config from .denoiser_scaling import DenoiserScaling from .discretizer import Discretization class Denoiser(nn.Module): def __init__(self, scaling_config: Dict): super().__init__() self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: return sigma def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: return c_noise def forward( self, network: nn.Module, input: torch.Tensor, sigma: torch.Tensor, cond: Dict, **additional_model_inputs, ) -> torch.Tensor: sigma = self.possibly_quantize_sigma(sigma) sigma_shape = sigma.shape sigma = append_dims(sigma, input.ndim) c_skip, c_out, c_in, c_noise = self.scaling(sigma) c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) c_skip, c_out, c_in, c_noise = c_skip.to(dtype=input.dtype), c_out.to(dtype=input.dtype), c_in.to(dtype=input.dtype), c_noise.to(dtype=input.dtype) return ( network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip ) class 
DiscreteDenoiser(Denoiser): def __init__( self, scaling_config: Dict, num_idx: int, discretization_config: Dict, do_append_zero: bool = False, quantize_c_noise: bool = True, flip: bool = True, ): super().__init__(scaling_config) self.discretization: Discretization = instantiate_from_config( discretization_config ) sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) self.register_buffer("sigmas", sigmas) self.quantize_c_noise = quantize_c_noise self.num_idx = num_idx def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: dists = sigma - self.sigmas[:, None] return dists.abs().argmin(dim=0).view(sigma.shape) def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: return self.sigmas[idx] def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: return self.idx_to_sigma(self.sigma_to_idx(sigma)) def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: if self.quantize_c_noise: return self.sigma_to_idx(c_noise) else: return c_noise ================================================ FILE: sgm/modules/diffusionmodules/denoiser_scaling.py ================================================ from abc import ABC, abstractmethod from typing import Tuple import torch class DenoiserScaling(ABC): @abstractmethod def __call__( self, sigma: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: pass class EDMScaling: def __init__(self, sigma_data: float = 0.5): self.sigma_data = sigma_data def __call__( self, sigma: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 c_noise = 0.25 * sigma.log() return c_skip, c_out, c_in, c_noise class EpsScaling: def __call__( self, sigma: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = torch.ones_like(sigma, device=sigma.device) c_out = -sigma c_in = 1 / (sigma**2 + 1.0) ** 0.5 c_noise = sigma.clone() return c_skip, c_out, c_in, c_noise class VScaling: def __call__( self, sigma: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = 1.0 / (sigma**2 + 1.0) c_out = -sigma / (sigma**2 + 1.0) ** 0.5 c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 c_noise = sigma.clone() return c_skip, c_out, c_in, c_noise class VScalingWithEDMcNoise(DenoiserScaling): def __call__( self, sigma: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = 1.0 / (sigma**2 + 1.0) c_out = -sigma / (sigma**2 + 1.0) ** 0.5 c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 c_noise = 0.25 * sigma.log() return c_skip, c_out, c_in, c_noise class VScalingWithEDMcNoise_fp16(DenoiserScaling): def __call__( self, sigma: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = 1.0 / (sigma**2 + 1.0) c_out = -sigma / (sigma**2 + 1.0) ** 0.5 c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 c_noise = 0.25 * sigma.log() c_skip, c_out, c_in, c_noise = c_skip.to(dtype=sigma.dtype), c_out.to(dtype=sigma.dtype), c_in.to(dtype=sigma.dtype), c_noise.to(dtype=sigma.dtype) return c_skip, c_out, c_in, c_noise ================================================ FILE: sgm/modules/diffusionmodules/denoiser_weighting.py ================================================ import torch class UnitWeighting: def __call__(self, sigma): return torch.ones_like(sigma, device=sigma.device) class EDMWeighting: def 
__init__(self, sigma_data=0.5): self.sigma_data = sigma_data def __call__(self, sigma): return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 class VWeighting(EDMWeighting): def __init__(self): super().__init__(sigma_data=1.0) class EpsWeighting: def __call__(self, sigma): return sigma**-2.0 ================================================ FILE: sgm/modules/diffusionmodules/discretizer.py ================================================ from abc import abstractmethod from functools import partial import numpy as np import torch from ...modules.diffusionmodules.util import make_beta_schedule from ...util import append_zero def generate_roughly_equally_spaced_steps( num_substeps: int, max_step: int ) -> np.ndarray: return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] class Discretization: def __call__(self, n, do_append_zero=True, device="cpu", flip=False): sigmas = self.get_sigmas(n, device=device) sigmas = append_zero(sigmas) if do_append_zero else sigmas return sigmas if not flip else torch.flip(sigmas, (0,)) @abstractmethod def get_sigmas(self, n, device): pass class EDMDiscretization(Discretization): def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): self.sigma_min = sigma_min self.sigma_max = sigma_max self.rho = rho def get_sigmas(self, n, device="cpu"): ramp = torch.linspace(0, 1, n, device=device) min_inv_rho = self.sigma_min ** (1 / self.rho) max_inv_rho = self.sigma_max ** (1 / self.rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho return sigmas class LegacyDDPMDiscretization(Discretization): def __init__( self, linear_start=0.00085, linear_end=0.0120, num_timesteps=1000, ): super().__init__() self.num_timesteps = num_timesteps betas = make_beta_schedule( "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end ) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.to_torch = partial(torch.tensor, dtype=torch.float32) def get_sigmas(self, n, device="cpu"): if n < self.num_timesteps: timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) alphas_cumprod = self.alphas_cumprod[timesteps] elif n == self.num_timesteps: alphas_cumprod = self.alphas_cumprod else: raise ValueError to_torch = partial(torch.tensor, dtype=torch.float32, device=device) sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 return torch.flip(sigmas, (0,)) ================================================ FILE: sgm/modules/diffusionmodules/guiders.py ================================================ import logging from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, Union import torch from einops import rearrange, repeat from ...util import append_dims, default logpy = logging.getLogger(__name__) class Guider(ABC): @abstractmethod def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: pass def prepare_inputs( self, x: torch.Tensor, s: float, c: Dict, uc: Dict ) -> Tuple[torch.Tensor, float, Dict]: pass class VanillaCFG(Guider): def __init__(self, scale: float): self.scale = scale def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: x_u, x_c = x.chunk(2) x_pred = x_u + self.scale * (x_c - x_u) return x_pred def prepare_inputs(self, x, s, c, uc): c_out = dict() for k in c: if k in ["vector", "crossattn", "concat"]: c_out[k] = torch.cat((uc[k], c[k]), 0) else: assert c[k] == uc[k] c_out[k] = c[k] return torch.cat([x] * 2), torch.cat([s] * 2), c_out class IdentityGuider(Guider): def __call__(self, x: 
torch.Tensor, sigma: float) -> torch.Tensor: return x def prepare_inputs( self, x: torch.Tensor, s: float, c: Dict, uc: Dict ) -> Tuple[torch.Tensor, float, Dict]: c_out = dict() for k in c: c_out[k] = c[k] return x, s, c_out class LinearPredictionGuider(Guider): def __init__( self, max_scale: float, num_frames: int, min_scale: float = 1.0, additional_cond_keys: Optional[Union[List[str], str]] = None, ): self.min_scale = min_scale self.max_scale = max_scale self.num_frames = num_frames self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)#.flip(dims=(1,)) additional_cond_keys = default(additional_cond_keys, []) if isinstance(additional_cond_keys, str): additional_cond_keys = [additional_cond_keys] self.additional_cond_keys = additional_cond_keys def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: x_u, x_c = x.chunk(2) x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) scale = append_dims(scale, x_u.ndim).to(x_u.device, dtype=x_u.dtype) return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") def prepare_inputs( self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict ) -> Tuple[torch.Tensor, torch.Tensor, dict]: c_out = dict() for k in c: if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: c_out[k] = torch.cat((uc[k], c[k]), 0) else: assert c[k] == uc[k] c_out[k] = c[k] return torch.cat([x] * 2), torch.cat([s] * 2), c_out class LinearPredictionGuider_fp16(Guider): def __init__( self, max_scale: float, num_frames: int, min_scale: float = 1.0, additional_cond_keys: Optional[Union[List[str], str]] = None, ): self.min_scale = min_scale self.max_scale = max_scale self.num_frames = num_frames self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)#.flip(dims=(1,)) additional_cond_keys = default(additional_cond_keys, []) if isinstance(additional_cond_keys, str): additional_cond_keys = [additional_cond_keys] self.additional_cond_keys = additional_cond_keys def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: x_u, x_c = x.chunk(2) x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) scale = append_dims(scale, x_u.ndim).to(x_u.device, dtype=x.dtype) return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") def prepare_inputs( self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict ) -> Tuple[torch.Tensor, torch.Tensor, dict]: c_out = dict() for k in c: if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: c_out[k] = torch.cat((uc[k], c[k]), 0) else: assert c[k] == uc[k] c_out[k] = c[k] return torch.cat([x] * 2), torch.cat([s] * 2), c_out ================================================ FILE: sgm/modules/diffusionmodules/model.py ================================================ # pytorch_diffusion + derived encoder decoder import logging import math from typing import Any, Callable, Optional import numpy as np import torch import torch.nn as nn from einops import rearrange from packaging import version logpy = logging.getLogger(__name__) try: import xformers import xformers.ops XFORMERS_IS_AVAILABLE = True except: XFORMERS_IS_AVAILABLE = False logpy.warning("no module 'xformers'. 
Processing without...") from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm( num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True ) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=1, padding=1 ) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=2, padding=0 ) def forward(self, x): if self.with_conv: pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class ResnetBlock(nn.Module): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: self.nin_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x, temb): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels 
= in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.k = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.v = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) def attention(self, h_: torch.Tensor) -> torch.Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) b, c, h, w = q.shape q, k, v = map( lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) ) h_ = torch.nn.functional.scaled_dot_product_attention( q, k, v ) # scale is dim ** -0.5 per default # compute attention return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) def forward(self, x, **kwargs): h_ = x h_ = self.attention(h_) h_ = self.proj_out(h_) return x + h_ class MemoryEfficientAttnBlock(nn.Module): """ Uses xformers efficient implementation, see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 Note: this is a single-head self-attention operation """ # def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.k = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.v = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.attention_op: Optional[Any] = None def attention(self, h_: torch.Tensor) -> torch.Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention B, C, H, W = q.shape q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) q, k, v = map( lambda t: t.unsqueeze(3) .reshape(B, t.shape[1], 1, C) .permute(0, 2, 1, 3) .reshape(B * 1, t.shape[1], C) .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention( q, k, v, attn_bias=None, op=self.attention_op ) out = ( out.unsqueeze(0) .reshape(B, 1, out.shape[1], C) .permute(0, 2, 1, 3) .reshape(B, out.shape[1], C) ) return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) def forward(self, x, **kwargs): h_ = x h_ = self.attention(h_) h_ = self.proj_out(h_) return x + h_ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): def forward(self, x, context=None, mask=None, **unused_kwargs): b, c, h, w = x.shape x = rearrange(x, "b c h w -> b (h w) c") out = super().forward(x, context=context, mask=mask) out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) return x + out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in [ "vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none", ], f"attn_type {attn_type} unknown" if ( version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none" ): assert XFORMERS_IS_AVAILABLE, ( f"We do not support vanilla attention in {torch.__version__} anymore, " f"as it is too expensive. Please install xformers via e.g. 
'pip install xformers==0.0.16'" ) attn_type = "vanilla-xformers" logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": assert attn_kwargs is None return AttnBlock(in_channels) elif attn_type == "vanilla-xformers": logpy.info( f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." ) return MemoryEfficientAttnBlock(in_channels) elif type == "memory-efficient-cross-attn": attn_kwargs["query_dim"] = in_channels return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) elif attn_type == "none": return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Model(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla", ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.use_timestep = use_timestep if self.use_timestep: # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList( [ torch.nn.Linear(self.ch, self.temb_ch), torch.nn.Linear(self.temb_ch, self.temb_ch), ] ) # downsampling self.conv_in = torch.nn.Conv2d( in_channels, self.ch, kernel_size=3, stride=1, padding=1 ) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] skip_in = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: skip_in = ch * in_ch_mult[i_level] block.append( ResnetBlock( in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d( block_in, out_ch, kernel_size=3, stride=1, padding=1 ) def forward(self, x, t=None, context=None): # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along 
channel axis x = torch.cat((x, context), dim=1) if self.use_timestep: # timestep embedding assert t is not None temb = get_timestep_embedding(t, self.ch) temb = self.temb.dense[0](temb) temb = nonlinearity(temb) temb = self.temb.dense[1](temb) else: temb = None # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block]( torch.cat([h, hs.pop()], dim=1), temb ) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h def get_last_layer(self): return self.conv_out.weight class Encoder(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", **ignore_kwargs, ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels # downsampling self.conv_in = torch.nn.Conv2d( in_channels, self.ch, kernel_size=3, stride=1, padding=1 ) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d( block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1, ) def forward(self, x): # timestep embedding temb = None # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # end h = 
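# --- Hedged usage sketch of the Encoder above with assumed, SD-VAE-style
# hyperparameters (requires the classes defined in this file and PyTorch >= 2.0
# or xformers for the mid-block attention).  With len(ch_mult) = 4 the input is
# downsampled 3 times (256 -> 32), and double_z=True doubles z_channels so the
# output can later be split into a mean and a log-variance.
import torch

enc = Encoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
    attn_resolutions=[], in_channels=3, resolution=256,
    z_channels=4, double_z=True,
)
with torch.no_grad():
    z = enc(torch.randn(1, 3, 256, 256))
print(z.shape)   # expected: torch.Size([1, 8, 32, 32])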
self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, attn_type="vanilla", **ignorekwargs, ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.give_pre_end = give_pre_end self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) logpy.info( "Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape) ) ) make_attn_cls = self._make_attn() make_resblock_cls = self._make_resblock() make_conv_cls = self._make_conv() # z to block_in self.conv_in = torch.nn.Conv2d( z_channels, block_in, kernel_size=3, stride=1, padding=1 ) # middle self.mid = nn.Module() self.mid.block_1 = make_resblock_cls( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) self.mid.block_2 = make_resblock_cls( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( make_resblock_cls( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn_cls(block_in, attn_type=attn_type)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = make_conv_cls( block_in, out_ch, kernel_size=3, stride=1, padding=1 ) def _make_attn(self) -> Callable: return make_attn def _make_resblock(self) -> Callable: return ResnetBlock def _make_conv(self) -> Callable: return torch.nn.Conv2d def get_last_layer(self, **kwargs): return self.conv_out.weight def forward(self, z, **kwargs): # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h = self.conv_in(z) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb, **kwargs) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h, **kwargs) if i_level != 0: h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h, **kwargs) if self.tanh_out: h = torch.tanh(h) return h ================================================ FILE: sgm/modules/diffusionmodules/openaimodel.py ================================================ import logging import math from abc import abstractmethod from typing import 
Iterable, List, Optional, Tuple, Union import torch as th import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.utils.checkpoint import checkpoint from ...modules.attention import SpatialTransformer from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module) from ...modules.video_attention import SpatialVideoTransformer from ...util import exists logpy = logging.getLogger(__name__) class AttentionPool2d(nn.Module): """ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py """ def __init__( self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: Optional[int] = None, ): super().__init__() self.positional_embedding = nn.Parameter( th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 ) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels self.attention = QKVAttention(self.num_heads) def forward(self, x: th.Tensor) -> th.Tensor: b, c, _ = x.shape x = x.reshape(b, c, -1) x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) x = x + self.positional_embedding[None, :, :].to(x.dtype) x = self.qkv_proj(x) x = self.attention(x) x = self.c_proj(x) return x[:, :, 0] class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. """ @abstractmethod def forward(self, x: th.Tensor, emb: th.Tensor): """ Apply the module to `x` given `emb` timestep embeddings. """ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. """ def forward( self, x: th.Tensor, emb: th.Tensor, context: Optional[th.Tensor] = None, image_only_indicator: Optional[th.Tensor] = None, time_context: Optional[int] = None, num_video_frames: Optional[int] = None, ): from ...modules.diffusionmodules.video_model import VideoResBlock for layer in self: module = layer if isinstance(module, TimestepBlock) and not isinstance( module, VideoResBlock ): x = layer(x, emb) elif isinstance(module, VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(module, SpatialVideoTransformer): x = layer( x, context, time_context, num_video_frames, image_only_indicator, ) elif isinstance(module, SpatialTransformer): x = layer(x, context) else: x = layer(x) return x class Upsample(nn.Module): """ An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. 
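# --- Minimal sketch (toy classes, not repository code) of the type-based routing
# that TimestepEmbedSequential performs above: children that are TimestepBlocks
# receive the timestep embedding as a second argument, everything else is called
# with the features only.  ToyTimestepBlock / ToySequential are hypothetical
# stand-ins for illustration.
import torch
import torch.nn as nn

class ToyTimestepBlock(nn.Module):
    def forward(self, x, emb):
        return x + emb.mean()          # stand-in for embedding-conditioned layers

class ToySequential(nn.Sequential):
    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, ToyTimestepBlock):
                x = layer(x, emb)      # conditioned child gets the embedding
            else:
                x = layer(x)           # plain child does not
        return x

block = ToySequential(nn.Conv2d(8, 8, 3, padding=1), ToyTimestepBlock())
out = block(torch.randn(1, 8, 16, 16), torch.randn(1, 256))
print(out.shape)                       # torch.Size([1, 8, 16, 16])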
""" def __init__( self, channels: int, use_conv: bool, dims: int = 2, out_channels: Optional[int] = None, padding: int = 1, third_up: bool = False, kernel_size: int = 3, scale_factor: int = 2, ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims self.third_up = third_up self.scale_factor = scale_factor if use_conv: self.conv = conv_nd( dims, self.channels, self.out_channels, kernel_size, padding=padding ) def forward(self, x: th.Tensor) -> th.Tensor: assert x.shape[1] == self.channels if self.dims == 3: t_factor = 1 if not self.third_up else self.scale_factor x = F.interpolate( x, ( t_factor * x.shape[2], x.shape[3] * self.scale_factor, x.shape[4] * self.scale_factor, ), mode="nearest", ) else: x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") if self.use_conv: x = self.conv(x) return x class Downsample(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__( self, channels: int, use_conv: bool, dims: int = 2, out_channels: Optional[int] = None, padding: int = 1, third_down: bool = False, ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) if use_conv: logpy.info(f"Building a Downsample layer with {dims} dims.") logpy.info( f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " f"kernel-size: 3, stride: {stride}, padding: {padding}" ) if dims == 3: logpy.info(f" --> Downsampling third axis (time): {third_down}") self.op = conv_nd( dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, ) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x: th.Tensor) -> th.Tensor: assert x.shape[1] == self.channels return self.op(x) class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. 
""" def __init__( self, channels: int, emb_channels: int, dropout: float, out_channels: Optional[int] = None, use_conv: bool = False, use_scale_shift_norm: bool = False, dims: int = 2, use_checkpoint: bool = False, up: bool = False, down: bool = False, kernel_size: int = 3, exchange_temb_dims: bool = False, skip_t_emb: bool = False, ): super().__init__() self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm self.exchange_temb_dims = exchange_temb_dims if isinstance(kernel_size, Iterable): padding = [k // 2 for k in kernel_size] else: padding = kernel_size // 2 self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) else: self.h_upd = self.x_upd = nn.Identity() self.skip_t_emb = skip_t_emb self.emb_out_channels = ( 2 * self.out_channels if use_scale_shift_norm else self.out_channels ) if self.skip_t_emb: logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}") assert not self.use_scale_shift_norm self.emb_layers = None self.exchange_temb_dims = False else: self.emb_layers = nn.Sequential( nn.SiLU(), linear( emb_channels, self.emb_out_channels, ), ) self.out_layers = nn.Sequential( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), zero_module( conv_nd( dims, self.out_channels, self.out_channels, kernel_size, padding=padding, ) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, kernel_size, padding=padding ) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ if self.use_checkpoint: return checkpoint(self._forward, x, emb) else: return self._forward(x, emb) def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: h = self.in_layers(x) if self.skip_t_emb: emb_out = th.zeros_like(h) else: emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = th.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: if self.exchange_temb_dims: emb_out = rearrange(emb_out, "b t c ... -> b c t ...") h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
""" def __init__( self, channels: int, num_heads: int = 1, num_head_channels: int = -1, use_checkpoint: bool = False, use_new_attention_order: bool = False, ): super().__init__() self.channels = channels if num_head_channels == -1: self.num_heads = num_heads else: assert ( channels % num_head_channels == 0 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint self.norm = normalization(channels) self.qkv = conv_nd(1, channels, channels * 3, 1) if use_new_attention_order: # split qkv before split heads self.attention = QKVAttention(self.num_heads) else: # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x: th.Tensor, **kwargs) -> th.Tensor: return checkpoint(self._forward, x) def _forward(self, x: th.Tensor) -> th.Tensor: b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) h = self.attention(qkv) h = self.proj_out(h) return (x + h).reshape(b, c, *spatial) class QKVAttentionLegacy(nn.Module): """ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping """ def __init__(self, n_heads: int): super().__init__() self.n_heads = n_heads def forward(self, qkv: th.Tensor) -> th.Tensor: """ Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after attention. """ bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( "bct,bcs->bts", q * scale, k * scale ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) class QKVAttention(nn.Module): """ A module which performs QKV attention and splits in a different order. """ def __init__(self, n_heads: int): super().__init__() self.n_heads = n_heads def forward(self, qkv: th.Tensor) -> th.Tensor: """ Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after attention. """ bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) q, k, v = qkv.chunk(3, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( "bct,bcs->bts", (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) class Timestep(nn.Module): def __init__(self, dim: int): super().__init__() self.dim = dim def forward(self, t: th.Tensor) -> th.Tensor: return timestep_embedding(t, self.dim) class UNetModel(nn.Module): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. 
For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially increased efficiency. """ def __init__( self, in_channels: int, model_channels: int, out_channels: int, num_res_blocks: int, attention_resolutions: int, dropout: float = 0.0, channel_mult: Union[List, Tuple] = (1, 2, 4, 8), conv_resample: bool = True, dims: int = 2, num_classes: Optional[Union[int, str]] = None, use_checkpoint: bool = False, num_heads: int = -1, num_head_channels: int = -1, num_heads_upsample: int = -1, use_scale_shift_norm: bool = False, resblock_updown: bool = False, transformer_depth: int = 1, context_dim: Optional[int] = None, disable_self_attentions: Optional[List[bool]] = None, num_attention_blocks: Optional[List[int]] = None, disable_middle_self_attn: bool = False, disable_middle_transformer: bool = False, use_linear_in_transformer: bool = False, spatial_transformer_attn_type: str = "softmax", adm_in_channels: Optional[int] = None, ): super().__init__() if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert ( num_head_channels != -1 ), "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: assert ( num_heads != -1 ), "Either num_heads or num_head_channels has to be set" self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels if isinstance(transformer_depth, int): transformer_depth = len(channel_mult) * [transformer_depth] transformer_depth_middle = transformer_depth[-1] if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): raise ValueError( "provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult" ) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all( map( lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)), ) ) logpy.info( f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " f"This option has LESS priority than attention_resolutions {attention_resolutions}, " f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " f"attention will still not be set." 
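# --- Illustrative sketch (assumed settings, not repository code) of how the
# downsample factor `ds` interacts with attention_resolutions in the UNet
# constructor below: ds starts at 1 and doubles after every level except the
# last, and a level only receives SpatialTransformer blocks when its ds value
# is listed in attention_resolutions.
channel_mult = (1, 2, 4, 8)
attention_resolutions = {2, 4}       # attend at 2x and 4x downsampling

ds = 1
for level, mult in enumerate(channel_mult):
    attends = "yes" if ds in attention_resolutions else "no"
    print(f"level {level}: ds={ds}, attention={attends}")
    if level != len(channel_mult) - 1:
        ds *= 2
# level 0: ds=1 no, level 1: ds=2 yes, level 2: ds=4 yes, level 3: ds=8 no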
) self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": logpy.info("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "timestep": self.label_emb = nn.Sequential( Timestep(model_channels), nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ), ) elif self.num_classes == "sequential": assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) ) else: raise ValueError self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if context_dim is not None and exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False if ( not exists(num_attention_blocks) or nr < num_attention_blocks[level] ): layers.append( SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, attn_type=spatial_transformer_attn_type, use_checkpoint=use_checkpoint, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, attn_type=spatial_transformer_attn_type, use_checkpoint=use_checkpoint, ) if not disable_middle_transformer else th.nn.Identity(), ResBlock( ch, time_embed_dim, 
dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False if ( not exists(num_attention_blocks) or i < num_attention_blocks[level] ): layers.append( SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, attn_type=spatial_transformer_attn_type, use_checkpoint=use_checkpoint, ) ) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) def forward( self, x: th.Tensor, timesteps: Optional[th.Tensor] = None, context: Optional[th.Tensor] = None, y: Optional[th.Tensor] = None, **kwargs, ) -> th.Tensor: """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. 
""" assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) h = x for module in self.input_blocks: h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: h = th.cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) return self.out(h) ================================================ FILE: sgm/modules/diffusionmodules/sampling.py ================================================ """ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py """ from typing import Dict, Union import torch from omegaconf import ListConfig, OmegaConf from tqdm import tqdm from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step, linear_multistep_coeff, to_d, to_neg_log_sigma, to_sigma) from ...util import append_dims, default, instantiate_from_config DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} class BaseDiffusionSampler: def __init__( self, discretization_config: Union[Dict, ListConfig, OmegaConf], num_steps: Union[int, None] = None, guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, verbose: bool = False, device: str = "cuda", ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) self.guider = instantiate_from_config( default( guider_config, DEFAULT_GUIDER, ) ) self.verbose = verbose self.device = device def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): sigmas = self.discretization( self.num_steps if num_steps is None else num_steps, device=self.device ) uc = default(uc, cond) x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) num_sigmas = len(sigmas) s_in = x.new_ones([x.shape[0]]) return x, s_in, sigmas, num_sigmas, cond, uc def denoise(self, x, denoiser, sigma, cond, uc): denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) denoised = self.guider(denoised, sigma) return denoised def get_sigma_gen(self, num_sigmas): sigma_generator = range(num_sigmas - 1) if self.verbose: print("#" * 30, " Sampling setting ", "#" * 30) print(f"Sampler: {self.__class__.__name__}") print(f"Discretization: {self.discretization.__class__.__name__}") print(f"Guider: {self.guider.__class__.__name__}") sigma_generator = tqdm( sigma_generator, total=num_sigmas, desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", ) return sigma_generator class SingleStepDiffusionSampler(BaseDiffusionSampler): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): raise NotImplementedError def euler_step(self, x, d, dt): return x + dt * d class EDMSampler(SingleStepDiffusionSampler): def __init__( self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs ): super().__init__(*args, **kwargs) self.s_churn = s_churn self.s_tmin = s_tmin self.s_tmax = s_tmax self.s_noise = s_noise def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) euler_step = 
self.euler_step(x, d, dt) x = self.possible_correction_step( euler_step, x, d, dt, next_sigma, denoiser, cond, uc ) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, cond, uc, num_steps ) for i in self.get_sigma_gen(num_sigmas): gamma = ( min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) x = self.sampler_step( s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma, ) return x class AncestralSampler(SingleStepDiffusionSampler): def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) self.eta = eta self.s_noise = s_noise self.noise_sampler = lambda x: torch.randn_like(x) def ancestral_euler_step(self, x, denoised, sigma, sigma_down): d = to_d(x, sigma, denoised) dt = append_dims(sigma_down - sigma, x.ndim) return self.euler_step(x, d, dt) def ancestral_step(self, x, sigma, next_sigma, sigma_up): x = torch.where( append_dims(next_sigma, x.ndim) > 0.0, x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), x, ) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, cond, uc, num_steps ) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, ) return x class LinearMultistepSampler(BaseDiffusionSampler): def __init__( self, order=4, *args, **kwargs, ): super().__init__(*args, **kwargs) self.order = order def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, cond, uc, num_steps ) ds = [] sigmas_cpu = sigmas.detach().cpu().numpy() for i in self.get_sigma_gen(num_sigmas): sigma = s_in * sigmas[i] denoised = denoiser( *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs ) denoised = self.guider(denoised, sigma) d = to_d(x, sigma, denoised) ds.append(d) if len(ds) > self.order: ds.pop(0) cur_order = min(i + 1, self.order) coeffs = [ linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order) ] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class EulerEDMSampler(EDMSampler): def possible_correction_step( self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc ): return euler_step class HeunEDMSampler(EDMSampler): def possible_correction_step( self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc ): if torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 return euler_step else: denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) d_new = to_d(euler_step, next_sigma, denoised) d_prime = (d + d_new) / 2.0 # apply correction if noise level is not 0 x = torch.where( append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step ) return x class EulerAncestralSampler(AncestralSampler): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) denoised = self.denoise(x, denoiser, sigma, cond, uc) x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) x = self.ancestral_step(x, sigma, next_sigma, sigma_up) return x class DPMPP2SAncestralSampler(AncestralSampler): def get_variables(self, sigma, sigma_down): t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] h = t_next - t s = t + 0.5 * h return h, s, t, t_next def get_mult(self, h, s, t, 
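# --- Minimal sketch (toy denoiser, assumed values; not repository code) of the
# Heun correction in HeunEDMSampler.possible_correction_step above: after the
# Euler step, the slope is re-evaluated at the destination noise level and the
# two slopes are averaged, turning the first-order update into a second-order
# one (the correction is skipped when next_sigma is zero).
import torch

def toy_denoiser(x, sigma):
    return 0.9 * x                                   # placeholder model

x = torch.randn(1, 4, 8, 8)
sigma_hat, next_sigma = torch.tensor(10.0), torch.tensor(7.0)
dt = next_sigma - sigma_hat

d = (x - toy_denoiser(x, sigma_hat)) / sigma_hat     # slope at the current sigma
euler_step = x + dt * d
d_new = (euler_step - toy_denoiser(euler_step, next_sigma)) / next_sigma
x_heun = x + dt * (d + d_new) / 2.0                  # averaged slope
print(x_heun.shape)                                  # torch.Size([1, 4, 8, 8])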
t_next): mult1 = to_sigma(s) / to_sigma(t) mult2 = (-0.5 * h).expm1() mult3 = to_sigma(t_next) / to_sigma(t) mult4 = (-h).expm1() return mult1, mult2, mult3, mult4 def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) denoised = self.denoise(x, denoiser, sigma, cond, uc) x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) if torch.sum(sigma_down) < 1e-14: # Save a network evaluation if all noise levels are 0 x = x_euler else: h, s, t, t_next = self.get_variables(sigma, sigma_down) mult = [ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) ] x2 = mult[0] * x - mult[1] * denoised denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) x_dpmpp2s = mult[2] * x - mult[3] * denoised2 # apply correction if noise level is not 0 x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) x = self.ancestral_step(x, sigma, next_sigma, sigma_up) return x class DPMPP2MSampler(BaseDiffusionSampler): def get_variables(self, sigma, next_sigma, previous_sigma=None): t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] h = t_next - t if previous_sigma is not None: h_last = t - to_neg_log_sigma(previous_sigma) r = h_last / h return h, r, t, t_next else: return h, None, t, t_next def get_mult(self, h, r, t, t_next, previous_sigma): mult1 = to_sigma(t_next) / to_sigma(t) mult2 = (-h).expm1() if previous_sigma is not None: mult3 = 1 + 1 / (2 * r) mult4 = 1 / (2 * r) return mult1, mult2, mult3, mult4 else: return mult1, mult2 def sampler_step( self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) mult = [ append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma) ] x_standard = mult[0] * x - mult[1] * denoised if old_denoised is None or torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 or on the first step return x_standard, denoised else: denoised_d = mult[2] * denoised - mult[3] * old_denoised x_advanced = mult[0] * x - mult[1] * denoised_d # apply correction if noise level is not 0 and not first step x = torch.where( append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard ) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, cond, uc, num_steps ) old_denoised = None for i in self.get_sigma_gen(num_sigmas): x, old_denoised = self.sampler_step( old_denoised, None if i == 0 else s_in * sigmas[i - 1], s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc=uc, ) return x ================================================ FILE: sgm/modules/diffusionmodules/sampling_utils.py ================================================ import torch from scipy import integrate from ...util import append_dims def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): if order - 1 > i: raise ValueError(f"Order {order} too high for step {i}") def fn(tau): prod = 1.0 for k in range(order): if j == k: continue prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) return prod return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] def get_ancestral_step(sigma_from, sigma_to, eta=1.0): if not eta: return sigma_to, 0.0 sigma_up = torch.minimum( sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, ) 
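# --- Worked sketch (assumed sigmas, toy model output; not repository code) of
# the first-order part of DPMPP2MSampler above, written in negative-log-sigma
# time: with t = -log(sigma) and h = t_next - t, the update is
# x_next = (sigma_next / sigma) * x - expm1(-h) * denoised; the second-order
# branch then mixes old_denoised into the same formula.
import torch

sigma, sigma_next = torch.tensor(10.0), torch.tensor(7.0)
x = torch.randn(1, 4, 8, 8)
denoised = 0.9 * x                                   # placeholder model output

t, t_next = -sigma.log(), -sigma_next.log()          # to_neg_log_sigma
h = t_next - t
mult1 = (-t_next).exp() / (-t).exp()                 # to_sigma(t_next) / to_sigma(t)
mult2 = (-h).expm1()
x_standard = mult1 * x - mult2 * denoised
print(torch.allclose(mult1, sigma_next / sigma))     # True: mult1 is just sigma_next / sigma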
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 return sigma_down, sigma_up def to_d(x, sigma, denoised): return (x - denoised) / append_dims(sigma, x.ndim) def to_neg_log_sigma(sigma): return sigma.log().neg() def to_sigma(neg_log_sigma): return neg_log_sigma.neg().exp() ================================================ FILE: sgm/modules/diffusionmodules/sigma_sampling.py ================================================ import torch from ...util import default, instantiate_from_config class EDMSampling: def __init__(self, p_mean=-1.2, p_std=1.2): self.p_mean = p_mean self.p_std = p_std def __call__(self, n_samples, rand=None): log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) return log_sigma.exp() class DiscreteSampling: def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): self.num_idx = num_idx self.sigmas = instantiate_from_config(discretization_config)( num_idx, do_append_zero=do_append_zero, flip=flip ) def idx_to_sigma(self, idx): return self.sigmas[idx] def __call__(self, n_samples, rand=None): idx = default( rand, torch.randint(0, self.num_idx, (n_samples,)), ) return self.idx_to_sigma(idx) ================================================ FILE: sgm/modules/diffusionmodules/util.py ================================================ """ partially adopted from https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py and https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py and https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py thanks! """ import math from typing import Optional import torch import torch.nn as nn from einops import rearrange, repeat def make_beta_schedule( schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, ): if schedule == "linear": betas = ( torch.linspace( linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 ) ** 2 ) return betas.numpy() def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def mixed_checkpoint(func, inputs: dict, params, flag): """ Evaluate a function without caching intermediate activations, allowing for reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that it also works with non-tensor inputs :param func: the function to evaluate. :param inputs: the argument dictionary to pass to `func`. :param params: a sequence of parameters `func` depends on but does not explicitly take as arguments. :param flag: if False, disable gradient checkpointing. 
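# --- Illustrative sketch (assumed values, not repository code) of EDMSampling
# above: training noise levels are drawn log-normally,
# sigma = exp(p_mean + p_std * eps) with eps ~ N(0, 1), so most samples land
# near exp(p_mean) with heavy tails towards both very small and very large sigmas.
import torch

p_mean, p_std, n_samples = -1.2, 1.2, 5
log_sigma = p_mean + p_std * torch.randn(n_samples)
sigma = log_sigma.exp()
print(sigma)          # five positive noise levels; exact values vary per run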
""" if flag: tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] tensor_inputs = [ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) ] non_tensor_keys = [ key for key in inputs if not isinstance(inputs[key], torch.Tensor) ] non_tensor_inputs = [ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) ] args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) return MixedCheckpointFunction.apply( func, len(tensor_inputs), len(non_tensor_inputs), tensor_keys, non_tensor_keys, *args, ) else: return func(**inputs) class MixedCheckpointFunction(torch.autograd.Function): @staticmethod def forward( ctx, run_function, length_tensors, length_non_tensors, tensor_keys, non_tensor_keys, *args, ): ctx.end_tensors = length_tensors ctx.end_non_tensors = length_tensors + length_non_tensors ctx.gpu_autocast_kwargs = { "enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled(), } assert ( len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors ) ctx.input_tensors = { key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) } ctx.input_non_tensors = { key: val for (key, val) in zip( non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) ) } ctx.run_function = run_function ctx.input_params = list(args[ctx.end_non_tensors :]) with torch.no_grad(): output_tensors = ctx.run_function( **ctx.input_tensors, **ctx.input_non_tensors ) return output_tensors @staticmethod def backward(ctx, *output_grads): # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} ctx.input_tensors = { key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors } with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. shallow_copies = { key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors } # shallow_copies.update(additional_args) output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) input_grads = torch.autograd.grad( output_tensors, list(ctx.input_tensors.values()) + ctx.input_params, output_grads, allow_unused=True, ) del ctx.input_tensors del ctx.input_params del output_tensors return ( (None, None, None, None, None) + input_grads[: ctx.end_tensors] + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + input_grads[ctx.end_tensors :] ) ckpt = torch.utils.checkpoint.checkpoint def checkpoint(func, inputs, params, flag): """ Evaluate a function without caching intermediate activations, allowing for reduced memory at the expense of extra compute in the backward pass. :param func: the function to evaluate. :param inputs: the argument sequence to pass to `func`. :param params: a sequence of parameters `func` depends on but does not explicitly take as arguments. :param flag: if False, disable gradient checkpointing. 
""" if flag: #args = tuple(inputs) + tuple(params) #return CheckpointFunction.apply(func, len(inputs), *args) return ckpt(func, *inputs) else: return func(*inputs) class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) ctx.gpu_autocast_kwargs = { "enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled(), } with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. shallow_copies = [x.view_as(x) for x in ctx.input_tensors] output_tensors = ctx.run_function(*shallow_copies) input_grads = torch.autograd.grad( output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True, ) del ctx.input_tensors del ctx.input_params del output_tensors return (None, None) + input_grads def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ if not repeat_only: half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) else: embedding = repeat(timesteps, "b -> b d", d=dim) return embedding def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def scale_module(module, scale): """ Scale the parameters of a module and return it. """ for p in module.parameters(): p.detach().mul_(scale) return module def mean_flat(tensor): """ Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) def normalization(channels): """ Make a standard normalization layer. :param channels: number of input channels. :return: an nn.Module for normalization. """ return GroupNorm32(32, channels) # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x).type(x.dtype) def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") def linear(*args, **kwargs): """ Create a linear module. """ return nn.Linear(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. 
""" if dims == 1: return nn.AvgPool1d(*args, **kwargs) elif dims == 2: return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] def __init__( self, alpha: float, merge_strategy: str = "learned_with_images", rearrange_pattern: str = "b t -> (b t) 1 1", ): super().__init__() self.merge_strategy = merge_strategy self.rearrange_pattern = rearrange_pattern assert ( merge_strategy in self.strategies ), f"merge_strategy needs to be in {self.strategies}" if self.merge_strategy == "fixed": self.register_buffer("mix_factor", torch.Tensor([alpha])) elif ( self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images" ): self.register_parameter( "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) ) else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: if self.merge_strategy == "fixed": alpha = self.mix_factor elif self.merge_strategy == "learned": alpha = torch.sigmoid(self.mix_factor) elif self.merge_strategy == "learned_with_images": assert image_only_indicator is not None, "need image_only_indicator ..." alpha = torch.where( image_only_indicator.bool(), torch.ones(1, 1, device=image_only_indicator.device), rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), ) alpha = rearrange(alpha, self.rearrange_pattern) else: raise NotImplementedError return alpha def forward( self, x_spatial: torch.Tensor, x_temporal: torch.Tensor, image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: alpha = self.get_alpha(image_only_indicator) x = ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal ) return x ================================================ FILE: sgm/modules/diffusionmodules/video_model.py ================================================ from functools import partial from typing import List, Optional, Union from einops import rearrange import torch from ...modules.diffusionmodules.openaimodel import * from ...modules.video_attention import SpatialVideoTransformer from ...util import default from .util import AlphaBlender class VideoResBlock(ResBlock): def __init__( self, channels: int, emb_channels: int, dropout: float, video_kernel_size: Union[int, List[int]] = 3, merge_strategy: str = "fixed", merge_factor: float = 0.5, out_channels: Optional[int] = None, use_conv: bool = False, use_scale_shift_norm: bool = False, dims: int = 2, use_checkpoint: bool = False, up: bool = False, down: bool = False, ): super().__init__( channels, emb_channels, dropout, out_channels=out_channels, use_conv=use_conv, use_scale_shift_norm=use_scale_shift_norm, dims=dims, use_checkpoint=use_checkpoint, up=up, down=down, ) self.time_stack = ResBlock( default(out_channels, channels), emb_channels, dropout=dropout, dims=3, out_channels=default(out_channels, channels), use_scale_shift_norm=False, use_conv=False, up=False, down=False, kernel_size=video_kernel_size, use_checkpoint=use_checkpoint, exchange_temb_dims=True, ) self.time_mixer = AlphaBlender( alpha=merge_factor, merge_strategy=merge_strategy, rearrange_pattern="b t -> b 1 t 1 1", ) def forward( self, x: th.Tensor, emb: th.Tensor, num_video_frames: int, image_only_indicator: Optional[th.Tensor] = None, ) -> th.Tensor: x = super().forward(x, emb) x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) x = rearrange(x, "(b 
t) c h w -> b c t h w", t=num_video_frames) x = self.time_stack( x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) ) x = self.time_mixer( x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator ) x = rearrange(x, "b c t h w -> (b t) c h w") return x class VideoUNet(nn.Module): def __init__( self, in_channels: int, model_channels: int, out_channels: int, num_frames: int, num_res_blocks: int, attention_resolutions: int, dropout: float = 0.0, channel_mult: List[int] = (1, 2, 4, 8), conv_resample: bool = True, dims: int = 2, num_classes: Optional[int] = None, use_checkpoint: bool = False, num_heads: int = -1, num_head_channels: int = -1, num_heads_upsample: int = -1, use_scale_shift_norm: bool = False, resblock_updown: bool = False, transformer_depth: Union[List[int], int] = 1, transformer_depth_middle: Optional[int] = None, context_dim: Optional[int] = None, time_downup: bool = False, time_context_dim: Optional[int] = None, extra_ff_mix_layer: bool = False, use_spatial_context: bool = False, merge_strategy: str = "fixed", merge_factor: float = 0.5, spatial_transformer_attn_type: str = "softmax", video_kernel_size: Union[int, List[int]] = 3, use_linear_in_transformer: bool = False, adm_in_channels: Optional[int] = None, disable_temporal_crossattention: bool = False, max_ddpm_temb_period: int = 10000, ): super(VideoUNet, self).__init__() assert context_dim is not None if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert num_head_channels != -1 if num_head_channels == -1: assert num_heads != -1 self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.num_frames = num_frames if isinstance(transformer_depth, int): transformer_depth = len(channel_mult) * [transformer_depth] transformer_depth_middle = default( transformer_depth_middle, transformer_depth[-1] ) self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "timestep": self.label_emb = nn.Sequential( Timestep(model_channels), nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ), ) elif self.num_classes == "sequential": assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) ) else: raise ValueError() self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 def get_attention_layer( ch, num_heads, dim_head, depth=1, context_dim=None, use_checkpoint=False, disabled_sa=False, ): return SpatialVideoTransformer( ch, num_heads, dim_head, depth=depth, 
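# --- Minimal sketch (toy tensor, not repository code) of the layout juggling in
# VideoResBlock.forward above: video batches are processed spatially as
# (B*T, C, H, W), folded to (B, C, T, H, W) for the temporal 3D ResBlock, and
# unfolded back afterwards; the fold/unfold round trip is lossless.
import torch
from einops import rearrange

B, T, C, H, W = 2, 14, 32, 16, 16
x = torch.randn(B * T, C, H, W)                          # frames flattened into the batch axis

x_vid = rearrange(x, "(b t) c h w -> b c t h w", t=T)    # fold time out for 3D convolutions
x_back = rearrange(x_vid, "b c t h w -> (b t) c h w")    # unfold again for spatial layers
print(torch.equal(x, x_back))                            # True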
context_dim=context_dim, time_context_dim=time_context_dim, dropout=dropout, ff_in=extra_ff_mix_layer, use_spatial_context=use_spatial_context, merge_strategy=merge_strategy, merge_factor=merge_factor, checkpoint=use_checkpoint, use_linear=use_linear_in_transformer, attn_mode=spatial_transformer_attn_type, disable_self_attn=disabled_sa, disable_temporal_crossattention=disable_temporal_crossattention, max_time_embed_period=max_ddpm_temb_period, ) def get_resblock( merge_factor, merge_strategy, video_kernel_size, ch, time_embed_dim, dropout, out_ch, dims, use_checkpoint, use_scale_shift_norm, down=False, up=False, ): return VideoResBlock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, channels=ch, emb_channels=time_embed_dim, dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=down, up=up, ) for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels layers.append( get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, use_checkpoint=use_checkpoint, disabled_sa=False, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: ds *= 2 out_ch = ch self.input_blocks.append( TimestepEmbedSequential( get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch, third_down=time_downup, ) ) ) ch = out_ch input_block_chans.append(ch) self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels self.middle_block = TimestepEmbedSequential( get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, out_ch=None, dropout=dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, use_checkpoint=use_checkpoint, ), get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, out_ch=None, time_embed_dim=time_embed_dim, dropout=dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch + ich, time_embed_dim=time_embed_dim, dropout=dropout, 
out_ch=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels layers.append( get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, use_checkpoint=use_checkpoint, disabled_sa=False, ) ) if level and i == num_res_blocks: out_ch = ch ds //= 2 layers.append( get_resblock( merge_factor=merge_factor, merge_strategy=merge_strategy, video_kernel_size=video_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, ) if resblock_updown else Upsample( ch, conv_resample, dims=dims, out_channels=out_ch, third_up=time_downup, ) ) self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) def forward( self, x: th.Tensor, timesteps: th.Tensor, context: Optional[th.Tensor] = None, y: Optional[th.Tensor] = None, time_context: Optional[th.Tensor] = None, num_video_frames: Optional[int] = None, image_only_indicator: Optional[th.Tensor] = None, ): assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) ## tbd: check the role of "image_only_indicator" num_video_frames = self.num_frames image_only_indicator = torch.zeros( x.shape[0]//num_video_frames, num_video_frames ).to(x.device) if image_only_indicator is None else image_only_indicator if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) ## x shape: [bt,c,h,w] h = x hs = [] for module in self.input_blocks: h = module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) hs.append(h) h = self.middle_block( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) for module in self.output_blocks: h = th.cat([h, hs.pop()], dim=1) h = module( h, emb, context=context, image_only_indicator=image_only_indicator, time_context=time_context, num_video_frames=num_video_frames, ) h = h.type(x.dtype) return self.out(h) ================================================ FILE: sgm/modules/diffusionmodules/wrappers.py ================================================ import torch import torch.nn as nn from packaging import version OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" class IdentityWrapper(nn.Module): def __init__(self, diffusion_model, compile_model: bool = False): super().__init__() compile = ( torch.compile if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model else lambda x: x ) self.diffusion_model = compile(diffusion_model) def forward(self, *args, **kwargs): return self.diffusion_model(*args, **kwargs) class OpenAIWrapper(IdentityWrapper): def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) return self.diffusion_model( x, timesteps=t, 
context=c.get("crossattn", None), y=c.get("vector", None), **kwargs, ) ================================================ FILE: sgm/modules/distributions/__init__.py ================================================ ================================================ FILE: sgm/modules/distributions/distributions.py ================================================ import numpy as np import torch class AbstractDistribution: def sample(self): raise NotImplementedError() def mode(self): raise NotImplementedError() class DiracDistribution(AbstractDistribution): def __init__(self, value): self.value = value def sample(self): return self.value def mode(self): return self.value class DiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: self.var = self.std = torch.zeros_like(self.mean).to( device=self.parameters.device ) def sample(self): x = self.mean + self.std * torch.randn(self.mean.shape).to( device=self.parameters.device ) return x def kl(self, other=None): if self.deterministic: return torch.Tensor([0.0]) else: if other is None: return 0.5 * torch.sum( torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3], ) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3], ) def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims, ) def mode(self): return self.mean def normal_kl(mean1, logvar1, mean2, logvar2): """ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 Compute the KL divergence between two gaussians. Shapes are automatically broadcasted, so batches can be compared to scalars, among other use cases. """ tensor = None for obj in (mean1, logvar1, mean2, logvar2): if isinstance(obj, torch.Tensor): tensor = obj break assert tensor is not None, "at least one argument must be a Tensor" # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). 
logvar1, logvar2 = [ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2) ] return 0.5 * ( -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) ================================================ FILE: sgm/modules/ema.py ================================================ import torch from torch import nn class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: raise ValueError("Decay must be between 0 and 1") self.m_name2s_name = {} self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) self.register_buffer( "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), ) for name, p in model.named_parameters(): if p.requires_grad: # remove as '.'-character is not allowed in buffers s_name = name.replace(".", "") self.m_name2s_name.update({name: s_name}) self.register_buffer(s_name, p.clone().detach().data) self.collected_params = [] def reset_num_updates(self): del self.num_updates self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay if self.num_updates >= 0: self.num_updates += 1 decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay with torch.no_grad(): m_param = dict(model.named_parameters()) shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: sname = self.m_name2s_name[key] shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_( one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: assert not key in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: assert not key in self.m_name2s_name def store(self, parameters): """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.collected_params = [param.clone() for param in parameters] def restore(self, parameters): """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method. After validation (or model saving), use this to restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. 
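        A typical evaluation pattern with this class (a sketch, not code taken from
        this repository) is:

            ema.store(model.parameters())    # stash the live weights
            ema.copy_to(model)               # swap in the EMA shadow weights
            ...                              # validate / save checkpoints
            ema.restore(model.parameters())  # put the live weights back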
""" for c_param, param in zip(self.collected_params, parameters): param.data.copy_(c_param.data) ================================================ FILE: sgm/modules/encoders/__init__.py ================================================ ================================================ FILE: sgm/modules/encoders/modules.py ================================================ import math from contextlib import nullcontext from functools import partial from typing import Dict, List, Optional, Tuple, Union import kornia import numpy as np import open_clip import torch import torch.nn as nn from einops import rearrange, repeat from omegaconf import ListConfig from torch.utils.checkpoint import checkpoint from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer) from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer from ...modules.diffusionmodules.model import Encoder from ...modules.diffusionmodules.openaimodel import Timestep from ...modules.diffusionmodules.util import (extract_into_tensor, make_beta_schedule) from ...modules.distributions.distributions import DiagonalGaussianDistribution from ...util import (append_dims, autocast, count_params, default, disabled_train, expand_dims_like, instantiate_from_config) class AbstractEmbModel(nn.Module): def __init__(self): super().__init__() self._is_trainable = None self._ucg_rate = None self._input_key = None @property def is_trainable(self) -> bool: return self._is_trainable @property def ucg_rate(self) -> Union[float, torch.Tensor]: return self._ucg_rate @property def input_key(self) -> str: return self._input_key @is_trainable.setter def is_trainable(self, value: bool): self._is_trainable = value @ucg_rate.setter def ucg_rate(self, value: Union[float, torch.Tensor]): self._ucg_rate = value @input_key.setter def input_key(self, value: str): self._input_key = value @is_trainable.deleter def is_trainable(self): del self._is_trainable @ucg_rate.deleter def ucg_rate(self): del self._ucg_rate @input_key.deleter def input_key(self): del self._input_key class GeneralConditioner(nn.Module): OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} def __init__(self, emb_models: Union[List, ListConfig]): super().__init__() embedders = [] for n, embconfig in enumerate(emb_models): embedder = instantiate_from_config(embconfig) assert isinstance( embedder, AbstractEmbModel ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" embedder.is_trainable = embconfig.get("is_trainable", False) embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) if not embedder.is_trainable: embedder.train = disabled_train for param in embedder.parameters(): param.requires_grad = False embedder.eval() print( f"Initialized embedder #{n}: {embedder.__class__.__name__} " f"with {count_params(embedder, False)} params. 
Trainable: {embedder.is_trainable}" ) if "input_key" in embconfig: embedder.input_key = embconfig["input_key"] elif "input_keys" in embconfig: embedder.input_keys = embconfig["input_keys"] else: raise KeyError( f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" ) embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) if embedder.legacy_ucg_val is not None: embedder.ucg_prng = np.random.RandomState() embedders.append(embedder) self.embedders = nn.ModuleList(embedders) def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: assert embedder.legacy_ucg_val is not None p = embedder.ucg_rate val = embedder.legacy_ucg_val for i in range(len(batch[embedder.input_key])): if embedder.ucg_prng.choice(2, p=[1 - p, p]): batch[embedder.input_key][i] = val return batch def forward( self, batch: Dict, force_zero_embeddings: Optional[List] = None ) -> Dict: output = dict() if force_zero_embeddings is None: force_zero_embeddings = [] for embedder in self.embedders: embedding_context = nullcontext if embedder.is_trainable else torch.no_grad with embedding_context(): if hasattr(embedder, "input_key") and (embedder.input_key is not None): if embedder.legacy_ucg_val is not None: batch = self.possibly_get_ucg_val(embedder, batch) # print(embedder.input_key) emb_out = embedder(batch[embedder.input_key]) elif hasattr(embedder, "input_keys"): emb_out = embedder(*[batch[k] for k in embedder.input_keys]) assert isinstance( emb_out, (torch.Tensor, list, tuple) ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" if not isinstance(emb_out, (list, tuple)): emb_out = [emb_out] for emb in emb_out: out_key = self.OUTPUT_DIM2KEYS[emb.dim()] if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: emb = ( expand_dims_like( torch.bernoulli( (1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device) ), emb, ) * emb ) if ( hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings ): emb = torch.zeros_like(emb) if out_key in output: output[out_key] = torch.cat( (output[out_key], emb), self.KEY2CATDIM[out_key] ) else: output[out_key] = emb return output def get_unconditional_conditioning( self, batch_c: Dict, batch_uc: Optional[Dict] = None, force_uc_zero_embeddings: Optional[List[str]] = None, force_cond_zero_embeddings: Optional[List[str]] = None, ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] ucg_rates = list() for embedder in self.embedders: ucg_rates.append(embedder.ucg_rate) embedder.ucg_rate = 0.0 c = self(batch_c, force_cond_zero_embeddings) uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) for embedder, rate in zip(self.embedders, ucg_rates): embedder.ucg_rate = rate return c, uc class InceptionV3(nn.Module): """Wrapper around the https://github.com/mseitzer/pytorch-fid inception port with an additional squeeze at the end""" def __init__(self, normalize_input=False, **kwargs): super().__init__() from pytorch_fid import inception kwargs["resize_input"] = True self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs) def forward(self, inp): outp = self.model(inp) if len(outp) == 1: return outp[0].squeeze() return outp class IdentityEncoder(AbstractEmbModel): def encode(self, x): return x def forward(self, x): return x class ClassEmbedder(AbstractEmbModel): def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False): super().__init__() self.embedding = nn.Embedding(n_classes, embed_dim) self.n_classes = 
n_classes self.add_sequence_dim = add_sequence_dim def forward(self, c): c = self.embedding(c) if self.add_sequence_dim: c = c[:, None, :] return c def get_unconditional_conditioning(self, bs, device="cuda"): uc_class = ( self.n_classes - 1 ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) uc = torch.ones((bs,), device=device) * uc_class uc = {self.key: uc.long()} return uc class ClassEmbedderForMultiCond(ClassEmbedder): def forward(self, batch, key=None, disable_dropout=False): out = batch key = default(key, self.key) islist = isinstance(batch[key], list) if islist: batch[key] = batch[key][0] c_out = super().forward(batch, key, disable_dropout) out[key] = [c_out] if islist else c_out return out class FrozenT5Embedder(AbstractEmbModel): """Uses the T5 transformer encoder for text""" def __init__( self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device self.max_length = max_length if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text): batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) tokens = batch_encoding["input_ids"].to(self.device) with torch.autocast("cuda", enabled=False): outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state return z def encode(self, text): return self(text) class FrozenByT5Embedder(AbstractEmbModel): """ Uses the ByT5 transformer encoder for text. Is character-aware. 
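    Unlike the subword T5 tokenizer used above, ByT5 tokenizes raw UTF-8 bytes, so it
    has no fixed word vocabulary; encode(text) returns the encoder's last hidden state,
    exactly as FrozenT5Embedder does.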
""" def __init__( self, version="google/byt5-base", device="cuda", max_length=77, freeze=True ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = ByT5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device self.max_length = max_length if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text): batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) tokens = batch_encoding["input_ids"].to(self.device) with torch.autocast("cuda", enabled=False): outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state return z def encode(self, text): return self(text) class FrozenCLIPEmbedder(AbstractEmbModel): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = ["last", "pooled", "hidden"] def __init__( self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, layer="last", layer_idx=None, always_return_pooled=False, ): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version) self.device = device self.max_length = max_length if freeze: self.freeze() self.layer = layer self.layer_idx = layer_idx self.return_pooled = always_return_pooled if layer == "hidden": assert layer_idx is not None assert 0 <= abs(layer_idx) <= 12 def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False @autocast def forward(self, text): batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer( input_ids=tokens, output_hidden_states=self.layer == "hidden" ) if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": z = outputs.pooler_output[:, None, :] else: z = outputs.hidden_states[self.layer_idx] if self.return_pooled: return z, outputs.pooler_output return z def encode(self, text): return self(text) class FrozenOpenCLIPEmbedder2(AbstractEmbModel): """ Uses the OpenCLIP transformer encoder for text """ LAYERS = ["pooled", "last", "penultimate"] def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last", always_return_pooled=False, legacy=True, ): super().__init__() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device("cpu"), pretrained=version, ) del model.visual self.model = model self.device = device self.max_length = max_length self.return_pooled = always_return_pooled if freeze: self.freeze() self.layer = layer if self.layer == "last": self.layer_idx = 0 elif self.layer == "penultimate": self.layer_idx = 1 else: raise NotImplementedError() self.legacy = legacy def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False @autocast def forward(self, text): tokens = open_clip.tokenize(text) z = self.encode_with_transformer(tokens.to(self.device)) if not self.return_pooled and self.legacy: return z if self.return_pooled: assert not self.legacy return 
z[self.layer], z["pooled"] return z[self.layer] def encode_with_transformer(self, text): x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] x = x + self.model.positional_embedding x = x.permute(1, 0, 2) # NLD -> LND x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) if self.legacy: x = x[self.layer] x = self.model.ln_final(x) return x else: # x is a dict and will stay a dict o = x["last"] o = self.model.ln_final(o) pooled = self.pool(o, text) x["pooled"] = pooled return x def pool(self, x, text): # take features from the eot embedding (eot_token is the highest number in each sequence) x = ( x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection ) return x def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): outputs = {} for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - 1: outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD if ( self.model.transformer.grad_checkpointing and not torch.jit.is_scripting() ): x = checkpoint(r, x, attn_mask) else: x = r(x, attn_mask=attn_mask) outputs["last"] = x.permute(1, 0, 2) # LND -> NLD return outputs def encode(self, text): return self(text) class FrozenOpenCLIPEmbedder(AbstractEmbModel): LAYERS = [ # "pooled", "last", "penultimate", ] def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last", ): super().__init__() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device("cpu"), pretrained=version ) del model.visual self.model = model self.device = device self.max_length = max_length if freeze: self.freeze() self.layer = layer if self.layer == "last": self.layer_idx = 0 elif self.layer == "penultimate": self.layer_idx = 1 else: raise NotImplementedError() def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text): tokens = open_clip.tokenize(text) z = self.encode_with_transformer(tokens.to(self.device)) return z def encode_with_transformer(self, text): x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] x = x + self.model.positional_embedding x = x.permute(1, 0, 2) # NLD -> LND x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.model.ln_final(x) return x def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break if ( self.model.transformer.grad_checkpointing and not torch.jit.is_scripting() ): x = checkpoint(r, x, attn_mask) else: x = r(x, attn_mask=attn_mask) return x def encode(self, text): return self(text) class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): """ Uses the OpenCLIP vision transformer encoder for images """ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, antialias=True, ucg_rate=0.0, unsqueeze_dim=False, repeat_to_max_len=False, num_image_crops=0, output_tokens=False, init_device=None, ): super().__init__() model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device(default(init_device, "cuda")), pretrained=version, ) del model.transformer self.model = model self.max_crops = num_image_crops self.pad_to_max_len = self.max_crops > 0 self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) self.device = device 
self.max_length = max_length if freeze: self.freeze() self.antialias = antialias self.register_buffer( "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False ) self.register_buffer( "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False ) self.ucg_rate = ucg_rate self.unsqueeze_dim = unsqueeze_dim self.stored_batch = None self.model.visual.output_tokens = output_tokens self.output_tokens = output_tokens def preprocess(self, x): # normalize to [0,1] x = kornia.geometry.resize( x, (224, 224), interpolation="bicubic", align_corners=True, antialias=self.antialias, ) x = (x + 1.0) / 2.0 # renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) return x def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False @autocast def forward(self, image, no_dropout=False): z = self.encode_with_vision_transformer(image) tokens = None if self.output_tokens: z, tokens = z[0], z[1] z = z.to(image.dtype) if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): z = ( torch.bernoulli( (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) )[:, None] * z ) if tokens is not None: tokens = ( expand_dims_like( torch.bernoulli( (1.0 - self.ucg_rate) * torch.ones(tokens.shape[0], device=tokens.device) ), tokens, ) * tokens ) if self.unsqueeze_dim: z = z[:, None, :] if self.output_tokens: assert not self.repeat_to_max_len assert not self.pad_to_max_len return tokens, z if self.repeat_to_max_len: if z.dim() == 2: z_ = z[:, None, :] else: z_ = z return repeat(z_, "b 1 d -> b n d", n=self.max_length), z elif self.pad_to_max_len: assert z.dim() == 3 z_pad = torch.cat( ( z, torch.zeros( z.shape[0], self.max_length - z.shape[1], z.shape[2], device=z.device, ), ), 1, ) return z_pad, z_pad[:, 0, ...] return z def encode_with_vision_transformer(self, img): # if self.max_crops > 0: # img = self.preprocess_by_cropping(img) if img.dim() == 5: assert self.max_crops == img.shape[1] img = rearrange(img, "b n c h w -> (b n) c h w") img = self.preprocess(img) if not self.output_tokens: assert not self.model.visual.output_tokens x = self.model.visual(img) tokens = None else: assert self.model.visual.output_tokens x, tokens = self.model.visual(img) if self.max_crops > 0: x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) # drop out between 0 and all along the sequence axis x = ( torch.bernoulli( (1.0 - self.ucg_rate) * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) ) * x ) if tokens is not None: tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) print( f"You are running very experimental token-concat in {self.__class__.__name__}. " f"Check what you are doing, and then remove this message." ) if self.output_tokens: return x, tokens return x def encode(self, text): return self(text) class FrozenCLIPT5Encoder(AbstractEmbModel): def __init__( self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", clip_max_length=77, t5_max_length=77, ): super().__init__() self.clip_encoder = FrozenCLIPEmbedder( clip_version, device, max_length=clip_max_length ) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) print( f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." 
) def encode(self, text): return self(text) def forward(self, text): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] class SpatialRescaler(nn.Module): def __init__( self, n_stages=1, method="bilinear", multiplier=0.5, in_channels=3, out_channels=None, bias=False, wrap_video=False, kernel_size=1, remap_output=False, ): super().__init__() self.n_stages = n_stages assert self.n_stages >= 0 assert method in [ "nearest", "linear", "bilinear", "trilinear", "bicubic", "area", ] self.multiplier = multiplier self.interpolator = partial(torch.nn.functional.interpolate, mode=method) self.remap_output = out_channels is not None or remap_output if self.remap_output: print( f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." ) self.channel_mapper = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, bias=bias, padding=kernel_size // 2, ) self.wrap_video = wrap_video def forward(self, x): if self.wrap_video and x.ndim == 5: B, C, T, H, W = x.shape x = rearrange(x, "b c t h w -> b t c h w") x = rearrange(x, "b t c h w -> (b t) c h w") for stage in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) if self.wrap_video: x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C) x = rearrange(x, "b t c h w -> b c t h w") if self.remap_output: x = self.channel_mapper(x) return x def encode(self, x): return self(x) class LowScaleEncoder(nn.Module): def __init__( self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64, scale_factor=1.0, ): super().__init__() self.max_noise_level = max_noise_level self.model = instantiate_from_config(model_config) self.augmentation_schedule = self.register_schedule( timesteps=timesteps, linear_start=linear_start, linear_end=linear_end ) self.out_size = output_size self.scale_factor = scale_factor def register_schedule( self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): betas = make_beta_schedule( beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert ( alphas_cumprod.shape[0] == self.num_timesteps ), "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer("betas", to_torch(betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) self.register_buffer( "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) ) self.register_buffer( "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) ) self.register_buffer( "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) ) self.register_buffer( "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def forward(self, 
x): z = self.model.encode(x) if isinstance(z, DiagonalGaussianDistribution): z = z.sample() z = z * self.scale_factor noise_level = torch.randint( 0, self.max_noise_level, (x.shape[0],), device=x.device ).long() z = self.q_sample(z, noise_level) if self.out_size is not None: z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") return z, noise_level def decode(self, z): z = z / self.scale_factor return self.model.decode(z) class ConcatTimestepEmbedderND(AbstractEmbModel): """embeds each dimension independently and concatenates them""" def __init__(self, outdim): super().__init__() self.timestep = Timestep(outdim) self.outdim = outdim def forward(self, x): if x.ndim == 1: x = x[:, None] assert len(x.shape) == 2 b, dims = x.shape[0], x.shape[1] x = rearrange(x, "b d -> (b d)") emb = self.timestep(x) emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) return emb class GaussianEncoder(Encoder, AbstractEmbModel): def __init__( self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs ): super().__init__(*args, **kwargs) self.posterior = DiagonalGaussianRegularizer() self.weight = weight self.flatten_output = flatten_output def forward(self, x) -> Tuple[Dict, torch.Tensor]: z = super().forward(x) z, log = self.posterior(z) log["loss"] = log["kl_loss"] log["weight"] = self.weight if self.flatten_output: z = rearrange(z, "b c h w -> b (h w ) c") return log, z class VideoPredictionEmbedderWithEncoder(AbstractEmbModel): def __init__( self, n_cond_frames: int, n_copies: int, encoder_config: dict, sigma_sampler_config: Optional[dict] = None, sigma_cond_config: Optional[dict] = None, is_ae: bool = False, scale_factor: float = 1.0, disable_encoder_autocast: bool = False, en_and_decode_n_samples_a_time: Optional[int] = None, ): super().__init__() self.n_cond_frames = n_cond_frames self.n_copies = n_copies self.encoder = instantiate_from_config(encoder_config) self.sigma_sampler = ( instantiate_from_config(sigma_sampler_config) if sigma_sampler_config is not None else None ) self.sigma_cond = ( instantiate_from_config(sigma_cond_config) if sigma_cond_config is not None else None ) self.is_ae = is_ae self.scale_factor = scale_factor self.disable_encoder_autocast = disable_encoder_autocast self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time def forward( self, vid: torch.Tensor ) -> Union[ torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, dict], Tuple[Tuple[torch.Tensor, torch.Tensor], dict], ]: if self.sigma_sampler is not None: b = vid.shape[0] // self.n_cond_frames sigmas = self.sigma_sampler(b).to(vid.device) if self.sigma_cond is not None: sigma_cond = self.sigma_cond(sigmas) sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies) sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames) noise = torch.randn_like(vid) vid = vid + noise * append_dims(sigmas, vid.ndim) with torch.autocast("cuda", enabled=not self.disable_encoder_autocast): n_samples = ( self.en_and_decode_n_samples_a_time if self.en_and_decode_n_samples_a_time is not None else vid.shape[0] ) n_rounds = math.ceil(vid.shape[0] / n_samples) all_out = [] for n in range(n_rounds): if self.is_ae: out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples]) else: out = self.encoder(vid[n * n_samples : (n + 1) * n_samples]) all_out.append(out) vid = torch.cat(all_out, dim=0) vid *= self.scale_factor vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames) vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies) 
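        # At this point the n_cond_frames conditioning latents have been stacked along
        # the channel axis and broadcast to every output frame, i.e. vid has shape
        # (b * n_copies, n_cond_frames * c, h, w), where c is the per-frame channel
        # count of the encoder output, ready to serve as "concat" conditioning in
        # GeneralConditioner.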
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid return return_val class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel): def __init__( self, open_clip_embedding_config: Dict, n_cond_frames: int, n_copies: int, ): super().__init__() self.n_cond_frames = n_cond_frames self.n_copies = n_copies self.open_clip = instantiate_from_config(open_clip_embedding_config) def forward(self, vid): vid = self.open_clip(vid) vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames) vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies) return vid ================================================ FILE: sgm/modules/video_attention.py ================================================ import torch from ..modules.attention import * from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding class TimeMixSequential(nn.Sequential): def forward(self, x, context=None, timesteps=None): for layer in self: x = layer(x, context, timesteps) return x class VideoTransformerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, "softmax-xformers": MemoryEfficientCrossAttention, } def __init__( self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True, timesteps=None, ff_in=False, inner_dim=None, attn_mode="softmax", disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, ): super().__init__() attn_cls = self.ATTENTION_MODES[attn_mode] # print(attn_cls) # exit(0) self.ff_in = ff_in or inner_dim is not None if inner_dim is None: inner_dim = dim assert int(n_heads * d_head) == inner_dim self.is_res = inner_dim == dim if self.ff_in: self.norm_in = nn.LayerNorm(dim) self.ff_in = FeedForward( dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff ) self.timesteps = timesteps self.disable_self_attn = disable_self_attn if self.disable_self_attn: self.attn1 = attn_cls( query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, dropout=dropout, ) # is a cross-attention else: self.attn1 = attn_cls( query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is a self-attention self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) if disable_temporal_crossattention: if switch_temporal_ca_to_sa: raise ValueError else: self.attn2 = None else: self.norm2 = nn.LayerNorm(inner_dim) if switch_temporal_ca_to_sa: self.attn2 = attn_cls( query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout ) # is a self-attention else: self.attn2 = attn_cls( query_dim=inner_dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, ) # is self-attn if context is none self.norm1 = nn.LayerNorm(inner_dim) self.norm3 = nn.LayerNorm(inner_dim) self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa self.checkpoint = checkpoint ''' if self.checkpoint: print(f"{self.__class__.__name__} is using checkpointing") ''' def forward( self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None ) -> torch.Tensor: if self.checkpoint: return checkpoint(self._forward, x, context, timesteps, use_reentrant=False) # unsure: use_reentrant=False else: return self._forward(x, context, timesteps=timesteps) def _forward(self, x, context=None, timesteps=None): assert self.timesteps or timesteps assert not (self.timesteps and timesteps) or self.timesteps == timesteps timesteps = self.timesteps or timesteps B, S, C = x.shape x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) # print('in', x.shape) if self.ff_in: x_skip = x x = 
self.ff_in(self.norm_in(x)) if self.is_res: x += x_skip if self.disable_self_attn: x = self.attn1(self.norm1(x), context=context) + x else: # print(self.attn1) # exit(0) # print(x.shape) # x = self.attn1(self.norm1(x)) + x x = self.attn1(self.norm1(x), align_w_first_frame=True) + x if self.attn2 is not None: # print(self.switch_temporal_ca_to_sa) if self.switch_temporal_ca_to_sa: x = self.attn2(self.norm2(x)) + x else: # print('att2') x = self.attn2(self.norm2(x), context=context) + x # exit(0) x_skip = x x = self.ff(self.norm3(x)) if self.is_res: x += x_skip x = rearrange( x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) # print('out', x.shape) return x def get_last_layer(self): return self.ff.net[-1].weight class SpatialVideoTransformer(SpatialTransformer): def __init__( self, in_channels, n_heads, d_head, depth=1, dropout=0.0, use_linear=False, context_dim=None, use_spatial_context=False, timesteps=None, merge_strategy: str = "fixed", merge_factor: float = 0.5, time_context_dim=None, ff_in=False, checkpoint=False, time_depth=1, attn_mode="softmax", disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, ): super().__init__( in_channels, n_heads, d_head, depth=depth, dropout=dropout, attn_type=attn_mode, use_checkpoint=checkpoint, context_dim=context_dim, use_linear=use_linear, disable_self_attn=disable_self_attn, ) self.time_depth = time_depth self.depth = depth self.max_time_embed_period = max_time_embed_period time_mix_d_head = d_head n_time_mix_heads = n_heads time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) inner_dim = n_heads * d_head if use_spatial_context: time_context_dim = context_dim self.time_stack = nn.ModuleList( [ VideoTransformerBlock( inner_dim, n_time_mix_heads, time_mix_d_head, dropout=dropout, context_dim=time_context_dim, timesteps=timesteps, checkpoint=checkpoint, ff_in=ff_in, inner_dim=time_mix_inner_dim, attn_mode=attn_mode, disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, ) for _ in range(self.depth) ] ) assert len(self.time_stack) == len(self.transformer_blocks) self.use_spatial_context = use_spatial_context self.in_channels = in_channels time_embed_dim = self.in_channels * 4 self.time_pos_embed = nn.Sequential( linear(self.in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, self.in_channels), ) self.time_mixer = AlphaBlender( alpha=merge_factor, merge_strategy=merge_strategy ) def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, time_context: Optional[torch.Tensor] = None, timesteps: Optional[int] = None, image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: # print('all_in') _, _, h, w = x.shape x_in = x spatial_context = None if exists(context): spatial_context = context if self.use_spatial_context: assert ( context.ndim == 3 ), f"n dims of spatial context should be 3 but are {context.ndim}" time_context = context time_context_first_timestep = time_context[::timesteps] time_context = repeat( time_context_first_timestep, "b ... -> (b n) ...", n=h * w ) elif time_context is not None and not self.use_spatial_context: time_context = repeat(time_context, "b ... 
-> (b n) ...", n=h * w) if time_context.ndim == 2: time_context = rearrange(time_context, "b c -> b 1 c") x = self.norm(x) if not self.use_linear: x = self.proj_in(x) x = rearrange(x, "b c h w -> b (h w) c") if self.use_linear: x = self.proj_in(x) num_frames = torch.arange(timesteps, device=x.device) num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) num_frames = rearrange(num_frames, "b t -> (b t)") t_emb = timestep_embedding( num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period, ) emb = self.time_pos_embed(t_emb) emb = emb[:, None, :] for it_, (block, mix_block) in enumerate( zip(self.transformer_blocks, self.time_stack) ): # print('in', x.shape) x = block( x, context=spatial_context, ) # print('out', x.shape) x_mix = x x_mix = x_mix + emb x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) x = self.time_mixer( x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator, ) if self.use_linear: x = self.proj_out(x) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) if not self.use_linear: x = self.proj_out(x) out = x + x_in # print('all_out') return out ================================================ FILE: sgm/util.py ================================================ import functools import importlib import os from functools import partial from inspect import isfunction import fsspec import numpy as np import torch from PIL import Image, ImageDraw, ImageFont from safetensors.torch import load_file as load_safetensors def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def get_string_from_tuple(s): try: # Check if the string starts and ends with parentheses if s[0] == "(" and s[-1] == ")": # Convert the string to a tuple t = eval(s) # Check if the type of t is tuple if type(t) == tuple: return t[0] else: pass except: pass return s def is_power_of_two(n): """ chat.openai.com/chat Return True if n is a power of 2, otherwise return False. The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. 
""" if n <= 0: return False return (n & (n - 1)) == 0 def autocast(f, enabled=True): def do_autocast(*args, **kwargs): with torch.cuda.amp.autocast( enabled=enabled, dtype=torch.get_autocast_gpu_dtype(), cache_enabled=torch.is_autocast_cache_enabled(), ): return f(*args, **kwargs) return do_autocast def load_partial_from_config(config): return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) txts = list() for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) if isinstance(xc[bi], list): text_seq = xc[bi][0] else: text_seq = xc[bi] lines = "\n".join( text_seq[start : start + nc] for start in range(0, len(text_seq), nc) ) try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: print("Cant encode string for logging. Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) txts = torch.tensor(txts) return txts def partialclass(cls, *args, **kwargs): class NewCls(cls): __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) return NewCls def make_path_absolute(path): fs, p = fsspec.core.url_to_fs(path) if fs.protocol == "file": return os.path.abspath(p) return path def ismap(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] > 3) def isimage(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) def isheatmap(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 2 def isneighbors(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) def exists(x): return x is not None def expand_dims_like(x, y): while x.dim() != y.dim(): x = x.unsqueeze(-1) return x def default(val, d): if exists(val): return val return d() if isfunction(d) else d def mean_flat(tensor): """ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 Take the mean over all non-batch dimensions. 
""" return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False, invalidate_cache=True): module, cls = string.rsplit(".", 1) if invalidate_cache: importlib.invalidate_caches() if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError( f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" ) return x[(...,) + (None,) * dims_to_append] def load_model_from_config(config, ckpt, verbose=True, freeze=True): print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) else: raise NotImplementedError model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) if freeze: for param in model.parameters(): param.requires_grad = False model.eval() return model def get_configs_path() -> str: """ Get the `configs` directory. For a working copy, this is the one in the root of the repository, but for an installed copy, it's in the `sgm` package (see pyproject.toml). """ this_dir = os.path.dirname(__file__) candidates = ( os.path.join(this_dir, "configs"), os.path.join(this_dir, "..", "configs"), ) for candidate in candidates: candidate = os.path.abspath(candidate) if os.path.isdir(candidate): return candidate raise FileNotFoundError(f"Could not find SGM configs in {candidates}") def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): """ Will return the result of a recursive get attribute call. E.g.: a.b.c = getattr(getattr(a, "b"), "c") = get_nested_attribute(a, "b.c") If any part of the attribute call is an integer x with current obj a, will try to call a[x] instead of a.x first. 
""" attributes = attribute_path.split(".") if depth is not None and depth > 0: attributes = attributes[:depth] assert len(attributes) > 0, "At least one attribute should be selected" current_attribute = obj current_key = None for level, attribute in enumerate(attributes): current_key = ".".join(attributes[: level + 1]) try: id_ = int(attribute) current_attribute = current_attribute[id_] except ValueError: current_attribute = getattr(current_attribute, attribute) return (current_attribute, current_key) if return_key else current_attribute ================================================ FILE: utils/save_video.py ================================================ import os import numpy as np from tqdm import tqdm from PIL import Image from einops import rearrange import cv2 import torch import torchvision from torch import Tensor from torchvision.utils import make_grid from torchvision.transforms.functional import to_tensor def frames_to_mp4(frame_dir, output_path, fps): def read_first_n_frames(d: os.PathLike, num_frames: int): if num_frames: images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))[:num_frames]] else: images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))] images = [to_tensor(x) for x in images] return torch.stack(images) videos = read_first_n_frames(frame_dir, num_frames=None) videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1) torchvision.io.write_video(output_path, videos, fps=fps, video_codec='h264', options={'crf': '10'}) def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None): """ video: torch.Tensor, b,c,t,h,w, 0-1 if -1~1, enable rescale=True """ n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w nrow = int(np.sqrt(n)) if nrow is None else nrow frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w] grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w] grid = torch.clamp(grid.float(), -1., 1.) if rescale: grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True): assert(video.dim() == 5) # b,c,t,h,w assert(isinstance(video, torch.Tensor)) video = video.detach().cpu() if clamp: video = torch.clamp(video, -1., 1.) 
n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n))) for framesheet in video] # [3, grid_h, grid_w] grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w] if rescale: grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] path = os.path.join(root, filename) torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) def flow2rgb(flow_map, max_value): flow_map_np = flow_map _, h, w = flow_map_np.shape rgb_map = np.ones((3,h,w)).astype(np.float32) if max_value is not None: normalized_flow_map = flow_map_np / max_value else: normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max()) rgb_map[0] += normalized_flow_map[0] rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1]) rgb_map[2] += normalized_flow_map[1] return rgb_map.clip(0,1) def save_flow_video(flow_tensor, output_file, fps=10, max_flow=None): b, _, n, h, w = flow_tensor.shape fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(output_file, fourcc, fps, (w,int(b*h))) for i in range(n): color_flows = [] for k in range(b): flow = flow_tensor[k, :, i].cpu().data.numpy() color_flow = (flow2rgb(flow, max_flow) * 255).astype(np.uint8) color_flow = np.transpose(color_flow, (1, 2, 0)) color_flows.append(color_flow) color_flows = np.concatenate(color_flows, axis=0) video_writer.write(color_flows) video_writer.release() def save_rgb_video(flow_tensor, output_file, fps=10, max_flow=None): b, _, n, h, w = flow_tensor.shape fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(output_file, fourcc, fps, (w,int(b*h))) for i in range(n): color_flows = [] for k in range(b): flow = flow_tensor[k, :, i].cpu().data.numpy() color_flow = (flow * 255).astype(np.uint8) color_flow = np.transpose(color_flow, (1, 2, 0)) color_flows.append(color_flow) color_flows = np.concatenate(color_flows, axis=0) video_writer.write(color_flows) video_writer.release() def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True): if batch_logs is None: return None """ save images and videos from images dict """ def save_img_grid(grid, path, rescale): if rescale: grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) for key in batch_logs: value = batch_logs[key] if key == 'flow': path = os.path.join(save_dir, "%s-%s.mp4"%(key, filename)) save_flow_video(value, path) else: if isinstance(value, list) and isinstance(value[0], str): ## a batch of captions path = os.path.join(save_dir, "%s-%s.txt"%(key, filename)) with open(path, 'w') as f: for i, txt in enumerate(value): f.write(f'idx={i}, txt={txt}\n') f.close() elif isinstance(value, torch.Tensor) and value.dim() == 5: ## save video grids video = value # b,c,t,h,w ## only save grayscale or rgb mode if video.shape[1] != 1 and video.shape[1] != 3: continue n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(1)) for framesheet in video] #[3, n*h, 1*w] grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] if rescale: grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) path = os.path.join(save_dir, "%s-%s.mp4"%(key, 
filename)) torchvision.io.write_video(path, grid, fps=save_fps, video_codec='h264', options={'crf': '10'}) ## save frame sheet img = value video_frames = rearrange(img, 'b c t h w -> (b t) c h w') t = img.shape[2] grid = torchvision.utils.make_grid(video_frames, nrow=t) path = os.path.join(save_dir, "%s-%s.jpg"%(key, filename)) #save_img_grid(grid, path, rescale) elif isinstance(value, torch.Tensor) and value.dim() == 4: ## save image grids img = value ## only save grayscale or rgb mode if img.shape[1] != 1 and img.shape[1] != 3: continue n = img.shape[0] grid = torchvision.utils.make_grid(img, nrow=1) path = os.path.join(save_dir, "%s-%s.jpg"%(key, filename)) save_img_grid(grid, path, rescale) else: pass def prepare_to_log(batch_logs, max_images=100000, clamp=True): if batch_logs is None: return None # process for key in batch_logs: N = batch_logs[key].shape[0] if hasattr(batch_logs[key], 'shape') else len(batch_logs[key]) N = min(N, max_images) batch_logs[key] = batch_logs[key][:N] ## in batch_logs: images & caption if isinstance(batch_logs[key], torch.Tensor): batch_logs[key] = batch_logs[key].detach().cpu() if clamp and key!='flow': try: batch_logs[key] = torch.clamp(batch_logs[key].float(), -1., 1.) except RuntimeError: print("clamp_scalar_cpu not implemented for Half") return batch_logs # ---------------------------------------------------------------------------------------------- def fill_with_black_squares(video, desired_len: int) -> Tensor: if len(video) >= desired_len: return video return torch.cat([ video, torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1), ], dim=0) # ---------------------------------------------------------------------------------------------- def load_num_videos(data_path, num_videos): # first argument can be either data_path of np array if isinstance(data_path, str): videos = np.load(data_path)['arr_0'] # NTHWC elif isinstance(data_path, np.ndarray): videos = data_path else: raise Exception if num_videos is not None: videos = videos[:num_videos, :, :, :, :] return videos def npz_to_video_grid(data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True): # videos = torch.tensor(np.load(data_path)['arr_0']).permute(0,1,4,2,3).div_(255).mul_(2) - 1.0 # NTHWC->NTCHW, np int -> torch tensor 0-1 if isinstance(data_path, str): videos = load_num_videos(data_path, num_videos) elif isinstance(data_path, np.ndarray): videos = data_path else: raise Exception n,t,h,w,c = videos.shape videos_th = [] for i in range(n): video = videos[i, :,:,:,:] images = [video[j, :,:,:] for j in range(t)] images = [to_tensor(img) for img in images] video = torch.stack(images) videos_th.append(video) if verbose: videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW else: videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] if nrow is None: nrow = int(np.ceil(np.sqrt(n))) if verbose: frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')] else: frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] if os.path.dirname(out_path) != "": os.makedirs(os.path.dirname(out_path), exist_ok=True) frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C] torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'}) ================================================ FILE: 
================================================ FILE: utils/tools.py ================================================
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from sgm.util import default, instantiate_from_config
from einops import rearrange, repeat
import math


def quick_freeze(model):
    for name, param in model.named_parameters():
        param.requires_grad = False
    return model


def get_gaussian_kernel(kernel_size, sigma, channels):
    print('parameters of gaussian kernel: kernel_size: {}, sigma: {}, channels: {}'.format(kernel_size, sigma, channels))
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    mean = (kernel_size - 1) / 2.
    variance = sigma**2.
    gaussian_kernel = torch.exp(
        -torch.sum((xy_grid - mean)**2., dim=-1) / (2 * variance)
    )
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size,
                                groups=channels, bias=False, padding=kernel_size // 2)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False

    return gaussian_filter
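# Illustrative sketch (not repository code): applying the depthwise Gaussian filter
# returned by get_gaussian_kernel to a batch of 3-channel images. The _demo_ helper
# name, kernel settings and input shape are assumptions for illustration only.
def _demo_gaussian_filter():
    blur = get_gaussian_kernel(kernel_size=5, sigma=1.0, channels=3)
    dummy = torch.randn(2, 3, 64, 64)  # b, c, h, w
    return blur(dummy)  # spatial size preserved by padding=kernel_size//2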
-> b ...", b=N[0] ) else: batch[key] = value_dict[key] if T is not None: batch["num_video_frames"] = T for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc def load_model( config: str, ckpt: str, device: str, num_frames: int, num_steps: int, ): config = OmegaConf.load(config) config.model.params.ckpt_path = ckpt if device == "cuda": config.model.params.conditioner_config.params.emb_models[ 0 ].params.open_clip_embedding_config.params.init_device = device config.model.params.sampler_config.params.num_steps = num_steps config.model.params.sampler_config.params.guider_config.params.num_frames = ( num_frames ) if device == "cuda": #with torch.device(device): model = instantiate_from_config(config.model).to(device, dtype=torch.float16).eval() else: model = instantiate_from_config(config.model).to(device).eval() filter = None #DeepFloydDataFiltering(verbose=False, device=device) return model, filter ================================================ FILE: utils/visualizer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import numpy as np import imageio import torch from matplotlib import cm import torch.nn.functional as F import torchvision.transforms as transforms import matplotlib.pyplot as plt from PIL import Image, ImageDraw def read_video_from_path(path): try: reader = imageio.get_reader(path) except Exception as e: print("Error opening video file: ", e) return None frames = [] for i, im in enumerate(reader): frames.append(np.array(im)) return np.stack(frames) def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True): # Create a draw object draw = ImageDraw.Draw(rgb) # Calculate the bounding box of the circle left_up_point = (coord[0] - radius, coord[1] - radius) right_down_point = (coord[0] + radius, coord[1] + radius) # Draw the circle draw.ellipse( [left_up_point, right_down_point], fill=tuple(color) if visible else None, outline=tuple(color), ) return rgb def draw_line(rgb, coord_y, coord_x, color, linewidth): draw = ImageDraw.Draw(rgb) draw.line( (coord_y[0], coord_y[1], coord_x[0], coord_x[1]), fill=tuple(color), width=linewidth, ) return rgb def add_weighted(rgb, alpha, original, beta, gamma): return (rgb * alpha + original * beta + gamma).astype("uint8") class Visualizer: def __init__( self, save_dir: str = "./results", grayscale: bool = False, pad_value: int = 0, fps: int = 10, mode: str = "rainbow", # 'cool', 'optical_flow' linewidth: int = 2, show_first_frame: int = 10, tracks_leave_trace: int = 0, # -1 for infinite ): self.mode = mode self.save_dir = save_dir if mode == "rainbow": self.color_map = cm.get_cmap("gist_rainbow") elif mode == "cool": self.color_map = cm.get_cmap(mode) self.show_first_frame = show_first_frame self.grayscale = grayscale self.tracks_leave_trace = tracks_leave_trace self.pad_value = pad_value self.linewidth = linewidth self.fps = fps def visualize( self, video: torch.Tensor, # (B,T,C,H,W) tracks: torch.Tensor, # (B,T,N,2) visibility: torch.Tensor = None, # (B, T, N, 1) bool gt_tracks: torch.Tensor = None, # (B,T,N,2) segm_mask: torch.Tensor = None, # (B,1,H,W) filename: str = "video", writer=None, # tensorboard Summary Writer, used for visualization during training step: int = 0, query_frame: int = 0, save_video: bool = True, 
================================================ FILE: utils/visualizer.py ================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import imageio
import torch

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw


def read_video_from_path(path):
    try:
        reader = imageio.get_reader(path)
    except Exception as e:
        print("Error opening video file: ", e)
        return None
    frames = []
    for i, im in enumerate(reader):
        frames.append(np.array(im))
    return np.stack(frames)


def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
    # Create a draw object
    draw = ImageDraw.Draw(rgb)
    # Calculate the bounding box of the circle
    left_up_point = (coord[0] - radius, coord[1] - radius)
    right_down_point = (coord[0] + radius, coord[1] + radius)
    # Draw the circle
    draw.ellipse(
        [left_up_point, right_down_point],
        fill=tuple(color) if visible else None,
        outline=tuple(color),
    )
    return rgb


def draw_line(rgb, coord_y, coord_x, color, linewidth):
    draw = ImageDraw.Draw(rgb)
    draw.line(
        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
        fill=tuple(color),
        width=linewidth,
    )
    return rgb


def add_weighted(rgb, alpha, original, beta, gamma):
    return (rgb * alpha + original * beta + gamma).astype("uint8")


class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]

            # Prepare the video file path
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            # Create a writer object
            video_writer = imageio.get_writer(save_path, fps=self.fps)

            # Write frames to the video file
            # for frame in wide_list[2:-1]:
            print(len(wide_list))
            for frame in wide_list:
                video_writer.append_data(frame)

            video_writer.close()

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())

        vector_colors = np.zeros((T, N, 3))
        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])

        # draw points
        for t in range(query_frame, T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visibile = True
                if visibility is not None:
                    visibile = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visibile,
                        )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x N x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        gt_tracks: np.ndarray,  # T x N x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb
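# Illustrative sketch (not repository code): overlaying random point tracks on a random
# clip with the Visualizer above. The _demo_ helper, tensor shapes (following the comments
# on visualize()) and output directory are assumptions; writing the mp4 assumes an
# imageio ffmpeg backend is available.
def _demo_visualizer(out_dir="./results"):
    vis = Visualizer(save_dir=out_dir, fps=8, linewidth=2)
    video = torch.randint(0, 255, (1, 14, 3, 256, 256)).float()  # B,T,C,H,W
    tracks = torch.rand(1, 14, 8, 2) * 255                       # B,T,N,2 pixel coords
    return vis.visualize(video, tracks, filename="demo_tracks")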