Repository: MC-E/ReVideo
Branch: main
Commit: 174645250d9d
Files: 80
Total size: 445.7 KB

Directory structure:
gitextract_6p1lor9m/

├── LICENSE
├── README.md
├── configs/
│   ├── examples/
│   │   ├── constant_motion/
│   │   │   ├── head6.sh
│   │   │   ├── head7.sh
│   │   │   ├── kong.sh
│   │   │   ├── monkey.sh
│   │   │   ├── woman2.sh
│   │   │   └── woman5.sh
│   │   ├── multi_region/
│   │   │   ├── lawn2.sh
│   │   │   └── woman.sh
│   │   └── single_region/
│   │       ├── desert.sh
│   │       ├── dog.sh
│   │       ├── football.sh
│   │       ├── forest.sh
│   │       ├── head5.sh
│   │       ├── lawn.sh
│   │       ├── lizard.sh
│   │       ├── road.sh
│   │       ├── sea.sh
│   │       ├── sea2.sh
│   │       ├── sky.sh
│   │       └── woman4.sh
│   └── inference/
│       └── config_test.yaml
├── ctrl_model/
│   ├── diffusion_ctrl.py
│   └── svd_ctrl.py
├── main/
│   └── inference/
│       ├── sample_constant_motion.py
│       ├── sample_multi_region.py
│       └── sample_single_region.py
├── requirements.txt
├── sgm/
│   ├── __init__.py
│   ├── inference/
│   │   ├── api.py
│   │   └── helpers.py
│   ├── lr_scheduler.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   └── diffusion.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── autoencoding/
│   │   │   ├── __init__.py
│   │   │   ├── losses/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── discriminator_loss.py
│   │   │   │   └── lpips.py
│   │   │   ├── lpips/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── loss/
│   │   │   │   │   ├── .gitignore
│   │   │   │   │   ├── LICENSE
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── lpips.py
│   │   │   │   ├── model/
│   │   │   │   │   ├── LICENSE
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── model.py
│   │   │   │   ├── util.py
│   │   │   │   └── vqperceptual.py
│   │   │   ├── regularizers/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   └── quantize.py
│   │   │   └── temporal_ae.py
│   │   ├── diffusionmodules/
│   │   │   ├── __init__.py
│   │   │   ├── denoiser.py
│   │   │   ├── denoiser_scaling.py
│   │   │   ├── denoiser_weighting.py
│   │   │   ├── discretizer.py
│   │   │   ├── guiders.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   ├── sampling.py
│   │   │   ├── sampling_utils.py
│   │   │   ├── sigma_sampling.py
│   │   │   ├── util.py
│   │   │   ├── video_model.py
│   │   │   └── wrappers.py
│   │   ├── distributions/
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders/
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   └── video_attention.py
│   └── util.py
└── utils/
    ├── save_video.py
    ├── tools.py
    └── visualizer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT	
Dated: November 21, 2023

“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.

"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

"Stability AI" or "we" means Stability AI Ltd. 

"Software" means, collectively, Stability AI’s proprietary software, models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.

“Software Products” means Software and Documentation. 

By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.



License Rights and Redistribution. 
Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.     
b.	If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
2. 	  Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS  AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 
3.   Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 
3.   Intellectual Property.
a. 	No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. 
Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. 
If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. 
4.   Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. 



================================================
FILE: README.md
================================================
# ReVideo: Remake a Video with Motion and Content Control
[Chong Mou](https://scholar.google.com/citations?user=SYQoDk0AAAAJ&hl=zh-CN),
[Mingdeng Cao](https://scholar.google.com/citations?user=EcS0L5sAAAAJ&hl=en),
[Xintao Wang](https://xinntao.github.io/),
[Zhaoyang Zhang](https://zzyfd.github.io/),
[Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ),
[Jian Zhang](https://jianzhang.tech/)

[![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://mc-e.github.io/project/ReVideo/)
[![arXiv](https://img.shields.io/badge/ArXiv-2405.13865-brightgreen)](https://arxiv.org/abs/2405.13865)

---
## Introduction
ReVideo targets local video editing: it can modify both the visual content and the motion trajectory of a selected region.
<p align="center">
  <img src="asserts/teaser.jpg">
</p>

## 📰 **New Features/Updates**
- [2024/09/25] ReVideo has been accepted to NeurIPS 2024.
- [2024/06/26] We released the code of ReVideo.
- [2024/05/26] **Long video editing plan**: We are collaborating with the [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) team to replace SVD with the Sora framework, making ReVideo suitable for long video editing. Some preliminary results are shown below. The quality of this initial combination on long videos is still limited; we will continue the collaboration and release higher-quality long video editing models.
<table class="center">
<tr>
  <td style="text-align:center;"><b>Generated by Open-Sora</b></td>
  <td style="text-align:center;"><b>Editing Result</b></td>
</tr>
<tr>
  <td><video src="https://github.com/MC-E/ReVideo/assets/54032224/81241556-0f1b-438e-ba90-094d7cc0eded" autoplay></td>
  <td><video src="https://github.com/MC-E/ReVideo/assets/54032224/474b3620-f156-4d30-a473-cbbcc615f56c" autoplay></td>
</tr>
</table>
- [2024/05/23] Paper and project page of **ReVideo** are available.

## ✏️ Todo
- [x] Release the code (June 2024)

## 🔥🔥🔥 Main Features
### Change content & Customize motion trajectory
<table class="center">
<tr>
  <td style="text-align:center;"><b>Input Video</b></td>
  <td style="text-align:center;"><b>Editing Result</b></td>
</tr>
<tr>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/222f35da-7396-4989-a3c3-9ab4a2e5fa98" autoplay></td>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/c128f1d7-30e4-49e7-b6b7-9d5f428ff882" autoplay></td>
</tr>
</table>

### Change content & Keep motion trajectory
<table class="center">
<tr>
  <td style="text-align:center;"><b>Input Video</b></td>
  <td style="text-align:center;"><b>Editing Result</b></td>
</tr>
<tr>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/d25dce6a-88cf-45ad-9177-76df9fffe819" autoplay></td>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/06c8f19d-4569-417f-a4a3-1782a09404db" autoplay></td>
</tr>
</table>

### Keep content & Customize motion trajectory
<table class="center">
<tr>
  <td style="text-align:center;"><b>Input Video</b></td>
  <td style="text-align:center;"><b>Editing Result</b></td>
</tr>
<tr>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/490b4e9b-c1af-4f87-83de-c6b27f4a925b" autoplay></td>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/93f77c7b-23a8-4b1e-8e6d-1abf57fd1130" autoplay></td>
</tr>
</table>

### Multi-area Editing
<table class="center">
<tr>
  <td style="text-align:center;"><b>Input Video</b></td>
  <td style="text-align:center;"><b>Editing Result</b></td>
</tr>
<tr>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/339263b6-ea97-4c43-8617-b40459b1973c" autoplay></td>
  <td><video src="https://github.com/MC-E/DragonDiffusion/assets/54032224/7a005b3a-ff3e-492c-9643-0fd921b0b53e" autoplay></td>
</tr>
</table>

## 🔧 Dependencies and Installation

- Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 2.0.1](https://pytorch.org/)
```bash
pip install -r requirements.txt
```

## ⏬ Download Models 
All models are downloaded automatically. You can also download them manually from [Hugging Face](https://huggingface.co/Adapter/ReVideo).

**Since ReVideo is trained on top of [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), use of the model must follow Stable Video Diffusion's [NC-COMMUNITY LICENSE](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid/blob/main/LICENSE)!**

## 💻 How to Test
You can download the test set from [https://huggingface.co/Adapter/ReVideo](https://huggingface.co/Adapter/ReVideo); the example scripts expect it under `testset/`.
Inference requires at least `20GB` of GPU memory for editing a `768x1344` video.  

```bash
bash configs/examples/constant_motion/head6.sh
```
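As a rough guide to where that memory goes, the sketch below computes the latent tensor sizes for a `768x1344`, 14-frame edit. It assumes an SVD-style VAE that downsamples by 8 into 4 latent channels, a UNet whose four `channel_mult` levels halve the latent resolution three times, and that `in_channels: 8` in `configs/inference/config_test.yaml` means the noisy latent concatenated with the conditioning-frame latent. The helper names are hypothetical and not part of the repository.

```python
import math


def latent_shape(height, width, frames=14, latent_channels=4, vae_factor=8, unet_levels=4):
    """Return (frames, channels, h, w) of the video latent for a pixel resolution.

    Assumption: the VAE downsamples by `vae_factor`, and the UNet halves the latent
    resolution once between each of its `unet_levels` levels (channel_mult [1, 2, 4, 4]).
    """
    unet_factor = 2 ** (unet_levels - 1)   # 3 downsamplings -> factor 8
    multiple = vae_factor * unet_factor    # 8 * 8 = 64 pixels
    if height % multiple or width % multiple:
        raise ValueError(f"{height}x{width} is not a multiple of {multiple}")
    return (frames, latent_channels, height // vae_factor, width // vae_factor)


def unet_input_shape(height, width, frames=14):
    """Assumed UNet input: noisy latent + conditioning-frame latent, concatenated on channels."""
    f, c, h, w = latent_shape(height, width, frames)
    return (f, 2 * c, h, w)


if __name__ == "__main__":
    shape = latent_shape(768, 1344)
    print(shape, unet_input_shape(768, 1344), math.prod(shape))
    # → (14, 4, 96, 168) (14, 8, 96, 168) 903168
```

The latent itself is small; most of the memory goes to UNet and ControlNet activations, which is presumably why the scripts decode one frame at a time with `--decoding_t 1`.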

### Description of input parameters
```bash
--s_h         # The abscissa of the top-left corner of the editing region
--e_h         # The abscissa of the bottom-right corner of the editing region
--s_w         # The ordinate of the top-left corner of the editing region
--e_w         # The ordinate of the bottom-right corner of the editing region
--ps_h        # The abscissa of the trajectory start point
--pe_h        # The abscissa of the trajectory end point
--ps_w        # The ordinate of the trajectory start point
--pe_w        # The ordinate of the trajectory end point
--x_bias_all  # Horizontal offset of reciprocating motion (multi-region; repeat once per region)
--y_bias_all  # Vertical offset of reciprocating motion (multi-region; repeat once per region)
--x_bias      # Horizontal offset of reciprocating motion (single-region scripts)
--y_bias      # Vertical offset of reciprocating motion (single-region scripts)
```
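The multi-region scripts pass one value per region to the box and point flags, but repeat `--x_bias_all` and `--y_bias_all` once per region. The repository's actual parser in `main/inference/sample_multi_region.py` is not shown here, so the following is a hypothetical reconstruction consistent with `configs/examples/multi_region/lawn2.sh`: argparse's `action="append"` combined with `nargs="+"` turns each repetition into its own list. `build_parser` and `group_regions` are illustrative names, and the `int` types are an assumption.

```python
import argparse

# Flags that take exactly one value per editing region (as in the multi-region scripts).
POINT_FLAGS = ("s_h", "e_h", "s_w", "e_w", "ps_h", "pe_h", "ps_w", "pe_w")


def build_parser():
    """Hypothetical parser for the region/trajectory flags of the example scripts."""
    parser = argparse.ArgumentParser()
    for name in POINT_FLAGS:
        parser.add_argument(f"--{name}", type=int, nargs="+", default=[])
    for name in ("x_bias_all", "y_bias_all"):
        # Each occurrence appends one list, i.e. one offset sequence per region.
        parser.add_argument(f"--{name}", type=int, nargs="+", action="append")
    return parser


def group_regions(args):
    """Zip the per-region flags into one dict per editing region."""
    n = len(args.s_h)
    for name in POINT_FLAGS:
        count = len(getattr(args, name))
        if count != n:
            raise ValueError(f"--{name} has {count} values, expected {n}")
    # If a bias flag is omitted entirely, assume no reciprocating offset.
    x_all = args.x_bias_all or [[0] for _ in range(n)]
    y_all = args.y_bias_all or [[0] for _ in range(n)]
    if len(x_all) != n or len(y_all) != n:
        raise ValueError("expected one --x_bias_all and one --y_bias_all per region")
    return [
        {
            "box": (args.s_h[i], args.e_h[i], args.s_w[i], args.e_w[i]),
            "start": (args.ps_h[i], args.ps_w[i]),
            "end": (args.pe_h[i], args.pe_w[i]),
            "x_bias": x_all[i],
            "y_bias": y_all[i],
        }
        for i in range(n)
    ]


if __name__ == "__main__":
    # Region flags copied from configs/examples/multi_region/lawn2.sh.
    argv = (
        "--s_h 1579 2577 509 1405 2369 --e_h 2155 2991 821 1723 2725 "
        "--s_w 311 119 1207 1257 1189 --e_w 673 393 1653 1663 1641 "
        "--ps_h 2887 2027 675 1577 2561 --pe_h 2707 1777 675 1577 2561 "
        "--ps_w 221 449 1347 1395 1367 --pe_w 221 449 1347 1395 1367 "
        "--x_bias_all 0 --x_bias_all 0 --x_bias_all 10 20 30 20 10 0 "
        "--x_bias_all 10 20 30 20 10 0 --x_bias_all 10 20 30 20 10 0 "
        "--y_bias_all 0 --y_bias_all 0 --y_bias_all 0 --y_bias_all 0 --y_bias_all 0"
    ).split()
    for region in group_regions(build_parser().parse_args(argv)):
        print(region)
```

Validating the counts up front catches the most likely mistake when adding a region: forgetting one of the repeated `--x_bias_all`/`--y_bias_all` lines.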

## Related Works
<p>
[1] <a href="https://pika.art/">https://pika.art/</a>
</p>
<p>
[2] <a href="https://arxiv.org/abs/2308.08089">DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory</a>
</p>
<p>
[3] <a href="https://arxiv.org/abs/2403.07420">DragAnything: Motion Control for Anything using Entity Representation</a>
</p>
<p>
[4] <a href="https://arxiv.org/abs/2403.14468/">AnyV2V: A Plug-and-Play Framework For Any Video-to-Video Editing Tasks</a>
</p>

## 🤗 Acknowledgements
We thank the authors of [Stable Video Diffusion](https://github.com/Stability-AI/generative-models) for releasing their code.


================================================
FILE: configs/examples/constant_motion/head6.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/head6"
path_ref="testset/reference/head6.png"
res_dir="outputs"

python3 main/inference/sample_constant_motion.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 25 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 247 \
--e_h 661 \
--s_w 46 \
--e_w 620 \
--ps_h 389 350 464 385 \
--ps_w 264 363 363 440

================================================
FILE: configs/examples/constant_motion/head7.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/head7"
path_ref="testset/reference/head7.png"
res_dir="outputs"

python3 main/inference/sample_constant_motion.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 25 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 1672 \
--e_h 2433 \
--s_w 330 \
--e_w 704 \
--ps_h 1854 2082 1932 \
--ps_w 479 501 622

================================================
FILE: configs/examples/constant_motion/kong.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/kong"
path_ref="testset/reference/kong.png"
res_dir="outputs"

python3 main/inference/sample_constant_motion.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 290 \
--e_h 1128 \
--s_w 0 \
--e_w 456 \
--ps_h 720 911 835 877 \
--ps_w 223 258 150 123


================================================
FILE: configs/examples/constant_motion/monkey.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/monkey"
path_ref="testset/reference/monkey.png"
res_dir="outputs"

python3 main/inference/sample_constant_motion.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 25 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 461 \
--e_h 1039 \
--s_w 189 \
--e_w 770 \
--ps_h 615 736 694 \
--ps_w 457 461 586


================================================
FILE: configs/examples/constant_motion/woman2.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/woman2"
path_ref="testset/reference/woman2.png"
res_dir="outputs"

python3 main/inference/sample_constant_motion.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 797 \
--e_h 1098 \
--s_w 315 \
--e_w 647 \
--ps_h 888 1009 950 \
--ps_w 391 385 502


================================================
FILE: configs/examples/constant_motion/woman5.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/woman5"
path_ref="testset/reference/woman5.png"
res_dir="outputs"

python3 main/inference/sample_constant_motion.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 605 \
--e_h 1449 \
--s_w 599 \
--e_w 1929 \
--ps_h 825 1263 1671 \
--ps_w 1123 1287 1231


================================================
FILE: configs/examples/multi_region/lawn2.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/lawn2"
path_ref="testset/reference/lawn2.png"
res_dir="outputs"

python3 main/inference/sample_multi_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 6 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 1579 2577 509 1405 2369 \
--e_h 2155 2991 821 1723 2725 \
--s_w 311 119 1207 1257 1189 \
--e_w 673 393 1653 1663 1641 \
--ps_h 2887 2027 675 1577 2561 \
--pe_h 2707 1777 675 1577 2561 \
--ps_w 221 449 1347 1395 1367 \
--pe_w 221 449 1347 1395 1367 \
--x_bias_all 0 \
--x_bias_all 0 \
--x_bias_all 10 20 30 20 10 0 \
--x_bias_all 10 20 30 20 10 0 \
--x_bias_all 10 20 30 20 10 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 

================================================
FILE: configs/examples/multi_region/woman.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/woman"
path_ref="testset/reference/woman.png"
res_dir="outputs"

python3 main/inference/sample_multi_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 6 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 437 1624 2140 85 328 1921 \
--e_h 752 1968 2399 402 728 2336 \
--s_w 324 339 434 243 565 223 \
--e_w 597 515 648 442 771 412 \
--ps_h 531 1718 2209 169 409 2007 \
--pe_h 681 1900 2331 318 552 2261 \
--ps_w 397 396 589 284 646 292 \
--pe_w 523 433 495 381 697 353 \
--x_bias_all 0 \
--x_bias_all 0 \
--x_bias_all 0 \
--x_bias_all 0 \
--x_bias_all 0 \
--x_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0 \
--y_bias_all 0

================================================
FILE: configs/examples/single_region/desert.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/desert"
path_ref="testset/reference/desert.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 1021 \
--e_h 2428 \
--s_w 1290 \
--e_w 1896 \
--ps_h 1805 \
--pe_h 1935 \
--ps_w 1585 \
--pe_w 1554 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/dog.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/dog"
path_ref="testset/reference/dog.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 83 \
--e_h 1477 \
--s_w 97 \
--e_w 982 \
--ps_h 803 608 1022 \
--pe_h 574 529 811 \
--ps_w 429 261 340 \
--pe_w 429 261 340 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/football.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/football"
path_ref="testset/reference/football.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 105 \
--e_h 1119 \
--s_w 910 \
--e_w 1301 \
--ps_h 629 \
--pe_h 789 \
--ps_w 1075 \
--pe_w 1075 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/forest.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/forest"
path_ref="testset/reference/forest.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 1045 \
--e_h 1497 \
--s_w 1511 \
--e_w 2073 \
--ps_h 1275 \
--pe_h 1275 \
--ps_w 1695 \
--pe_w 1695 \
--x_bias 10 20 10 0 \
--y_bias 0 0 0 0


================================================
FILE: configs/examples/single_region/head5.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/head5"
path_ref="testset/reference/head5.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 828 \
--e_h 1494 \
--s_w 0 \
--e_w 766 \
--ps_h 1157 \
--pe_h 1356 \
--ps_w 529 \
--pe_w 529 \
--x_bias 0 0 0 0 \
--y_bias 10 20 10 0


================================================
FILE: configs/examples/single_region/lawn.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/lawn"
path_ref="testset/reference/lawn.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 383 \
--e_h 1109 \
--s_w 485 \
--e_w 1627 \
--ps_h 717 741 \
--pe_h 717 741 \
--ps_w 731 945 \
--pe_w 781 995 \
--x_bias 15 30 45 30 15 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/lizard.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/lizard"
path_ref="testset/reference/lizard.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 152 \
--e_h 1062 \
--s_w 248 \
--e_w 1021 \
--ps_h 520 \
--pe_h 700 \
--ps_w 542 \
--pe_w 542 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/road.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/road"
path_ref="testset/reference/road.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 886 \
--e_h 1235 \
--s_w 704 \
--e_w 944 \
--ps_h 1115 \
--pe_h 1046 \
--ps_w 790 \
--pe_w 881 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/sea.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/sea2"
path_ref="testset/reference/sea2.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 547 \
--e_h 1400 \
--s_w 488 \
--e_w 882 \
--ps_h 939 \
--pe_h 1026 \
--ps_w 649 \
--pe_w 687 \
--x_bias 0 0 0 0 \
--y_bias 10 20 10 0


================================================
FILE: configs/examples/single_region/sea2.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/sea2"
path_ref="testset/reference/sea2_2.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 306 \
--e_h 1077 \
--s_w 617 \
--e_w 962 \
--ps_h 737 \
--pe_h 485 \
--ps_w 761 \
--pe_w 764 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/sky.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/sky"
path_ref="testset/reference/sky.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 580 \
--e_h 1037 \
--s_w 397 \
--e_w 607 \
--ps_h 780 \
--pe_h 709 \
--ps_w 516 \
--pe_w 532 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/examples/single_region/woman4.sh
================================================
name="svd-example2[fps6_mb127-temp]"
config="configs/inference/config_test.yaml"
ckpt="ckpt/model.ckpt"
image_input="testset/woman4"
path_ref="testset/reference/woman4.png"
res_dir="outputs"

python3 main/inference/sample_single_region.py \
--seed 23 \
--ckpt $ckpt \
--config $config \
--savedir $res_dir/$name \
--savefps 10 \
--ddim_steps 15 \
--frames 14 \
--input $image_input \
--path_ref $path_ref \
--fps 10 \
--motion 127 \
--cond_aug 0.02 \
--decoding_t 1 --resize \
--s_h 94 \
--e_h 613 \
--s_w 236 \
--e_w 952 \
--ps_h 278 \
--pe_h 420 \
--ps_w 640 \
--pe_w 640 \
--x_bias 0 \
--y_bias 0


================================================
FILE: configs/inference/config_test.yaml
================================================
num_frames: &num_frames 14
model:
  base_learning_rate: 3.0e-5
  target: ctrl_model.diffusion_ctrl.DiffusionEngineCtrl
  params:
    scale_factor: 0.18215
    input_key: video
    items: ['canny', 'depth', 'sketch', 'org']
    probabilities: [0.3, 0.3, 0.3, 0.1]
    no_cond_log: true
    en_and_decode_n_samples_a_time: 1
    use_ema: false
    disable_first_stage_autocast: true

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: ctrl_model.svd_ctrl.ControledVideoUnet
      params:
        num_frames: *num_frames
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: true
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: true
        use_spatial_context: true
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    controlnet_config:
      target: ctrl_model.svd_ctrl.VideoCtrlNet
      params:
        num_frames: *num_frames
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: true
        in_channels: 8
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: true
        use_spatial_context: true
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]
        conditioning_channels: 2
        conditioning_channels_mask: 3
        conditioning_channels_region: 1
        conditioning_embedding_out_channels: [16, 32, 96, 256]
        ctrlnet_scale: 1
        if_init_from_unet: false
        ctrlnet_ckpt_path: ckpt/model.ckpt

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          input_key: cond_frames_without_noise
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: true
                # version: "ckpt/clip-h-14/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/blobs/9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4"

        - input_key: fps_id
          is_trainable: false
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: false
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: false
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: true
            n_cond_frames: 1
            n_copies: 1
            is_ae: true
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: true
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: false
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config: 
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 25
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            num_frames: *num_frames
            max_scale: 4 #2.5
            min_scale: 1 #1.0
            additional_cond_keys: ["ctrl_input", "mask", "region"]
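
The guider's `num_frames` resolves through the YAML alias `*num_frames` to the `&num_frames 14` anchor at the top of this file. As far as we can tell from upstream sgm, `LinearPredictionGuider` gives each frame its own classifier-free guidance scale, linearly spaced from `min_scale` to `max_scale`. It also batches the conditional and unconditional copies of every key listed in `additional_cond_keys` (here `ctrl_input`, `mask`, `region`) the same way it batches `crossattn` and `concat`. Below is a minimal numpy sketch of the per-frame blend under that assumption. `linear_prediction_guidance` is a hypothetical name, not a repository function.

```python
import numpy as np


def linear_prediction_guidance(x_u, x_c, num_frames, min_scale=1.0, max_scale=4.0):
    """Per-frame classifier-free guidance, assuming sgm LinearPredictionGuider semantics.

    x_u, x_c: unconditional / conditional denoiser outputs of shape (b * num_frames, ...).
    """
    b = x_u.shape[0] // num_frames
    rest = x_u.shape[1:]
    x_u = x_u.reshape(b, num_frames, *rest)
    x_c = x_c.reshape(b, num_frames, *rest)
    # One scale per frame, broadcast over the batch and the trailing (c, h, w) dims.
    scale = np.linspace(min_scale, max_scale, num_frames)
    scale = scale.reshape(1, num_frames, *([1] * len(rest)))
    out = x_u + scale * (x_c - x_u)
    return out.reshape(b * num_frames, *rest)


# Values from config_test.yaml: num_frames=14, min_scale=1, max_scale=4.
x_u = np.zeros((14, 4, 8, 8))
x_c = np.ones((14, 4, 8, 8))
guided = linear_prediction_guidance(x_u, x_c, num_frames=14)
print(guided[:, 0, 0, 0].round(2))
```

With these settings the first frame uses scale 1, which reduces to the conditional prediction `x_c`, and the last frame uses scale 4.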

================================================
FILE: ctrl_model/diffusion_ctrl.py
================================================
from functools import partial

from typing import Any, Dict, List, Optional, Tuple, Union

from einops import rearrange, repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch as th

from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from sgm.models.diffusion import DiffusionEngine
from sgm.util import (default, disabled_train, get_obj_from_str,
                    instantiate_from_config, log_txt_as_img)
from sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
import cv2
import numpy as np
from utils.save_video import save_rgb_video, save_flow_video

class DiffusionEngineCtrl(DiffusionEngine):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        controlnet_config: Optional[Dict] = None,
        ctrlnet_key: Optional[str] = None,
        items = None,
        probabilities = None,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
        kernel_size = 199,
        sigma = 20,
    ):
        super().__init__(
            network_config,
            denoiser_config,
            first_stage_config,
            conditioner_config,
            sampler_config,
            optimizer_config,
            scheduler_config,
            loss_fn_config,
            network_wrapper,
            ckpt_path,
            use_ema,
            ema_decay_rate,
            scale_factor,
            disable_first_stage_autocast,
            input_key,
            log_keys,
            no_cond_log,
            compile_model,
            en_and_decode_n_samples_a_time,
        )
        self.items = items
        self.probabilities = probabilities
        ctrlnet_model = instantiate_from_config(controlnet_config)
        if ctrlnet_model.ctrlnet_ckpt_path is not None:
            ctrlnet_model.init_from_ckpt()
        elif ctrlnet_model.if_init_from_unet:
            ctrlnet_model.init_from_unet(self.model.diffusion_model)
        else:
            print('ControlNet randomly initialized')
        self.ctrlnet_key = ctrlnet_key
        self.model = CtrlNetWrapper(
            self.model.diffusion_model,
            compile_model=False,  # the UNet may be compiled in the super().__init__()
            ctrlnet_model=ctrlnet_model,
        )
        

class CtrlNetWrapper(OpenAIWrapper):
    def __init__(self, diffusion_model, compile_model: bool = False, ctrlnet_model: nn.Module = None):
        super().__init__(diffusion_model, compile_model)
        self.ctrlnet_model = ctrlnet_model

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        ctrlnet_cond = c.get("ctrl_input")
        mask = c.get("mask")
        region = c.get("region")
        assert ctrlnet_cond is not None, "SVD CtrlNet condition c['ctrl_input'] is None!"
        down_block_res_samples, mid_block_res_sample = self.ctrlnet_model(
            x,
            mask=mask,
            region=region,
            conds=ctrlnet_cond,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            **kwargs,
        )
        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            mid_block_additional_residual=mid_block_res_sample,
            down_block_additional_residuals=down_block_res_samples,
            **kwargs,
        )
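
`CtrlNetWrapper.forward` concatenates `c["concat"]` onto the noisy latent along the channel axis. That is consistent with both networks being configured with `in_channels: 8`, assuming the concat entry carries a 4-channel encoded conditioning frame alongside the 4 latent channels. The control branch returns one residual per input block plus one for the middle block. `ControledVideoUnet.forward` then adds them to the matching skip connections by popping both lists from the end, so the deepest skip is paired with the deepest residual. The toy numpy sketch below illustrates only that data flow. `toy_unet` and `ctrl_wrapper_forward` are hypothetical names, and the arithmetic inside them is not the repository's model. It also shows why zero-initialized control convolutions (`make_zero_conv`) leave the base UNet's output unchanged.

```python
import numpy as np


def toy_unet(x, down_res=None, mid_res=None):
    """Hypothetical 3-level stand-in for the skip logic in ControledVideoUnet.forward."""
    hs = []
    h = x
    for gain in (1.0, 2.0, 3.0):  # "input_blocks": each stage records a skip feature
        h = h * gain
        hs.append(h)
    h = h + 1.0  # "middle_block"
    if mid_res is not None:
        h = h + mid_res  # mid_block_additional_residual
    # Copy before popping: the original pops from, and thereby empties, the caller's list.
    down_res = list(down_res) if down_res is not None else None
    while hs:  # "output_blocks": skips are consumed deepest-first
        skip = hs.pop()
        if down_res is not None:
            skip = skip + down_res.pop()  # residuals are popped in the same reverse order
        h = 0.5 * h + skip  # toy decoder step; the real UNet concatenates and convolves
    return h  # no output projection here, so the toy keeps all 8 channels


def ctrl_wrapper_forward(x, concat, ctrlnet, unet):
    """Mirror CtrlNetWrapper.forward: concat channels, run the control branch, then the UNet."""
    x = np.concatenate([x, concat], axis=1)  # [bt, 4, h, w] + [bt, 4, h, w] -> 8 channels
    down_res, mid_res = ctrlnet(x)
    return unet(x, down_res=down_res, mid_res=mid_res)


latent = np.ones((14, 4, 2, 2))  # 14 frames, 4 latent channels
concat = np.ones((14, 4, 2, 2))  # stand-in for the encoded conditioning frame


def zero_ctrl(x):
    """Zero-initialized control branch: all residuals are zero."""
    return [np.zeros_like(x) for _ in range(3)], np.zeros_like(x)


base = toy_unet(np.concatenate([latent, concat], axis=1))
out = ctrl_wrapper_forward(latent, concat, zero_ctrl, toy_unet)
print(out.shape, np.allclose(out, base))
```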

================================================
FILE: ctrl_model/svd_ctrl.py
================================================
from functools import partial

from typing import List, Optional, Union, Tuple

from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch as th

from sgm.modules.diffusionmodules.openaimodel import *
from sgm.modules.video_attention import SpatialVideoTransformer
from sgm.util import default
from safetensors.torch import load_file as load_safetensors
from sgm.modules.diffusionmodules.util import AlphaBlender
from sgm.modules.diffusionmodules.video_model import VideoResBlock, VideoUNet
import numpy as np
import cv2

class ControledVideoUnet(VideoUNet):
    """
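    VideoUNet whose forward() additionally accepts ControlNet residuals:
    `mid_block_additional_residual` is added to the middle-block output, and
    `down_block_additional_residuals` are added to the matching skip connections.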
    """
    def forward(
        self,
        x: th.Tensor,
        timesteps: th.Tensor,
        context: Optional[th.Tensor] = None,
        y: Optional[th.Tensor] = None,
        time_context: Optional[th.Tensor] = None,
        num_video_frames: Optional[int] = None,
        image_only_indicator: Optional[th.Tensor] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    ):
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        ## tbd: check the role of "image_only_indicator"
        num_video_frames = self.num_frames
        image_only_indicator = torch.zeros(
                    x.shape[0]//num_video_frames, num_video_frames
                ).to(x.device) if image_only_indicator is None else image_only_indicator

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        ## x shape: [bt,c,h,w]
        h = x
        hs = []
        for module in self.input_blocks:
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )
            hs.append(h)
        h = self.middle_block(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
        # svd ctrl
        if mid_block_additional_residual is not None:
            h = h + mid_block_additional_residual
        for module in self.output_blocks:
            if down_block_additional_residuals is not None:
                h = th.cat([h, hs.pop() + down_block_additional_residuals.pop()], dim=1)
            else:
                h = th.cat([h, hs.pop()], dim=1)
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )

        return self.out(h)


class ControlNetConditioningEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class MaskEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class WeightEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels*2, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )
        self.sig = nn.Sigmoid()
        self.conv_final = nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=3, padding=1)

    def forward(self, conditioning, t, cond_embeddings, mask_embeddings):
        t_cond = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(conditioning.shape)
        conditioning = torch.cat([conditioning, t_cond], dim=1)
        w = self.conv_in(conditioning)
        w = F.silu(w)

        for block in self.blocks:
            w = block(w)
            w = F.silu(w)

        w = self.conv_out(w)
        w = self.sig(w)

        embedding = cond_embeddings*w+mask_embeddings*(1-w)
        embedding = self.conv_final(embedding)

        return embedding


class VideoCtrlNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_frames: int,
        num_res_blocks: int,
        attention_resolutions: List[int],
        dropout: float = 0.0,
        channel_mult: List[int] = (1, 2, 4, 8),
        conv_resample: bool = True,
        dims: int = 2,
        num_classes: Optional[int] = None,
        use_checkpoint: bool = False,
        num_heads: int = -1,
        num_head_channels: int = -1,
        num_heads_upsample: int = -1,
        use_scale_shift_norm: bool = False,
        resblock_updown: bool = False,
        transformer_depth: Union[List[int], int] = 1,
        transformer_depth_middle: Optional[int] = None,
        context_dim: Optional[int] = None,
        time_downup: bool = False,
        time_context_dim: Optional[int] = None,
        extra_ff_mix_layer: bool = False,
        use_spatial_context: bool = False,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        spatial_transformer_attn_type: str = "softmax",
        video_kernel_size: Union[int, List[int]] = 3,
        use_linear_in_transformer: bool = False,
        adm_in_channels: Optional[int] = None,
        disable_temporal_crossattention: bool = False,
        max_ddpm_temb_period: int = 10000,
        # ctrlnet
        conditioning_channels: int = 3,
        conditioning_channels_mask: int = 1,
        conditioning_channels_region: int = 1,
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        ctrlnet_scale: float = 1.0,
        ctrlnet_ckpt_path: Optional[str] = None,
        if_init_from_unet = True,
    ):
        super(VideoCtrlNet, self).__init__()
        assert context_dim is not None

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1

        if num_head_channels == -1:
            assert num_heads != -1

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_frames = num_frames
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )
        self.dims = dims
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.if_init_from_unet = if_init_from_unet
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        self.ctrlnet_scale = ctrlnet_scale
        self.ctrlnet_ckpt_path = ctrlnet_ckpt_path

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = nn.Sequential(
                    Timestep(model_channels),
                    nn.Sequential(
                        linear(model_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    ),
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=model_channels,
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )
        self.mask_embedding = MaskEmbedding(
            conditioning_embedding_channels=model_channels,
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels_mask,
        )
        self.weight_embedding = WeightEmbedding(
            conditioning_embedding_channels=model_channels,
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels_region,
        )

        self.controlnet_down_blocks = nn.ModuleList([
            self.make_zero_conv(model_channels),
        ])

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1

        def get_attention_layer(
            ch,
            num_heads,
            dim_head,
            depth=1,
            context_dim=None,
            use_checkpoint=False,
            disabled_sa=False,
        ):
            return SpatialVideoTransformer(
                ch,
                num_heads,
                dim_head,
                depth=depth,
                context_dim=context_dim,
                time_context_dim=time_context_dim,
                dropout=dropout,
                ff_in=extra_ff_mix_layer,
                use_spatial_context=use_spatial_context,
                merge_strategy=merge_strategy,
                merge_factor=merge_factor,
                checkpoint=use_checkpoint,
                use_linear=use_linear_in_transformer,
                attn_mode=spatial_transformer_attn_type,
                disable_self_attn=disabled_sa,
                disable_temporal_crossattention=disable_temporal_crossattention,
                max_time_embed_period=max_ddpm_temb_period,
            )

        def get_resblock(
            merge_factor,
            merge_strategy,
            video_kernel_size,
            ch,
            time_embed_dim,
            dropout,
            out_ch,
            dims,
            use_checkpoint,
            use_scale_shift_norm,
            down=False,
            up=False,
        ):
            return VideoResBlock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                channels=ch,
                emb_channels=time_embed_dim,
                dropout=dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                down=down,
                up=up,
            )

        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_ch=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        get_attention_layer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth[level],
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                            disabled_sa=False,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                # ctrlnet blocks
                self.controlnet_down_blocks.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                ds *= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_ch=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            third_down=time_downup,
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.controlnet_down_blocks.append(self.make_zero_conv(ch))
                self._feature_size += ch
                # self.step_cur = 0

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels

        # ctrlnet mid block
        self.controlnet_mid_block = self.make_zero_conv(ch)

        self.middle_block = TimestepEmbedSequential(
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                time_embed_dim=time_embed_dim,
                out_ch=None,
                dropout=dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            get_attention_layer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth_middle,
                context_dim=context_dim,
                use_checkpoint=use_checkpoint,
            ),
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                out_ch=None,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def init_from_ckpt(self, ckpt_path=None):
        path = self.ctrlnet_ckpt_path
        if ckpt_path is not None:
            path = ckpt_path

        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
            sd = filter(lambda x: x[0].startswith("model.ctrlnet_model"), sd.items())
            sd = {k.replace("model.ctrlnet_model.", ""): v for k, v in sd}
        elif path.endswith("bin"):
            sd = torch.load(path, map_location="cpu")
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"@CtrlNet: Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def init_from_unet(self, unet):
        self.input_blocks.load_state_dict(unet.input_blocks.state_dict())
        self.middle_block.load_state_dict(unet.middle_block.state_dict())
        print("init from unet successfully!")

    def forward(
        self,
        x: th.Tensor,
        conds: th.Tensor,
        mask: th.Tensor,
        region: th.Tensor,
        timesteps: th.Tensor,
        context: Optional[th.Tensor] = None,
        y: Optional[th.Tensor] = None,
        time_context: Optional[th.Tensor] = None,
        num_video_frames: Optional[int] = None,
        image_only_indicator: Optional[th.Tensor] = None,
    ):
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        ## tbd: check the role of "image_only_indicator"
        num_video_frames = self.num_frames
        image_only_indicator = torch.zeros(
                    x.shape[0]//num_video_frames, num_video_frames
                ).to(x.device) if image_only_indicator is None else image_only_indicator

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

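        # ControlNet conditions: embed the control input (conds, e.g. the sparse flow) and the masked
        # content video (mask), then fuse them with weight_embedding, which also sees region and timesteps.
        cond_embeddings = self.controlnet_cond_embedding(conds)
        mask = mask.permute(0,2,1,3,4)  # [b, c, t, h, w] -> [b, t, c, h, w]
        mask = mask.reshape(mask.shape[0]*mask.shape[1], mask.shape[2], mask.shape[3], mask.shape[4])  # -> [b*t, c, h, w]
        mask_embeddings = self.mask_embedding(mask)
        cond_embeddings = self.weight_embedding(region, timesteps, cond_embeddings, mask_embeddings)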

        ## x shape: [bt,c,h,w]
        h = x
        down_block_res_samples = []
        for module, zero_module in zip(self.input_blocks, self.controlnet_down_blocks):
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )
            if cond_embeddings is not None:
                # Inject the fused condition once, right after the first (conv_in) block.
                h += cond_embeddings
                cond_embeddings = None
            down_block_res_samples.append(
                zero_module(
                    h,
                    emb,
                    context=context,
                    image_only_indicator=image_only_indicator,
                    time_context=time_context,
                    num_video_frames=num_video_frames,
                ) * self.ctrlnet_scale
            )

        h = self.middle_block(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
        mid_block_res_sample = self.controlnet_mid_block(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        ) * self.ctrlnet_scale
        return down_block_res_samples, mid_block_res_sample
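

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# forward() above returns one residual per input block plus one for the middle block,
# each passed through a "zero" projection and multiplied by self.ctrlnet_scale. The base
# UNet that consumes them is not shown in this file; this sketch ASSUMES the usual
# ControlNet wiring, where each residual is added to the matching skip connection and
# the middle residual to the middle-block output. numpy arrays stand in for tensors.
import numpy as np


def inject_ctrl_residuals(skips, mid, ctrl_skips, ctrl_mid):
    """Return UNet skip features and middle output with ControlNet residuals added."""
    if len(skips) != len(ctrl_skips):
        raise ValueError("expected exactly one control residual per skip connection")
    new_skips = [skip + res for skip, res in zip(skips, ctrl_skips)]
    return new_skips, mid + ctrl_mid

# Usage note: all-zero residuals (e.g. from zero-initialised projections) leave the UNet
# features unchanged, so under this assumption an untrained control branch is a no-op.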




================================================
FILE: main/inference/sample_constant_motion.py
================================================
import datetime, time
import os, sys, argparse
import math
from glob import glob
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor

sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from sgm.util import default, instantiate_from_config
from utils.save_video import save_flow_video, save_rgb_video
import torch.nn as nn
from utils.visualizer import Visualizer
from utils.tools import resize_pil_image, quick_freeze, get_gaussian_kernel, get_batch, get_unique_embedder_keys_from_conditioner, load_model

if not os.path.exists('ckpt'):
    os.makedirs('ckpt')
if not os.path.exists('ckpt/model.ckpt'):
    torch.hub.download_url_to_file(
        'https://huggingface.co/Adapter/ReVideo/resolve/main/model.ckpt',
        'ckpt/model.ckpt')
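

# --- Illustrative sketch (hypothetical helpers, not part of ReVideo) -------------
# sample() below resizes each frame with resize_pil_image, rounds both sides down to a
# multiple of 64, and rescales the user-given edit box by resized / original size.
# This script calls image ROWS "w" and COLUMNS "h": s_w/e_w are row (y) bounds scaled
# by the height ratio (scale_w), and s_h/e_h are column (x) bounds scaled by the width
# ratio (scale_h). To match the script, pass the size returned by resize_pil_image,
# i.e. measured before the 64-rounding.


def snap_to_multiple_of_64(width, height):
    """Round both sides down to the nearest multiple of 64."""
    return width - width % 64, height - height % 64


def scale_edit_box(s_w, e_w, s_h, e_h, orig_size, resized_size):
    """Map a (row_start, row_end, col_start, col_end) box to resized-image pixels.

    orig_size and resized_size are PIL-style (width, height) tuples.
    """
    scale_x = resized_size[0] / orig_size[0]  # the script's scale_h
    scale_y = resized_size[1] / orig_size[1]  # the script's scale_w
    return int(s_w * scale_y), int(e_w * scale_y), int(s_h * scale_x), int(e_h * scale_x)

# Example: for a 2000x1000 frame resized to 1000x250, rows 100:200 and columns 50:150
# map to rows 25:50 and columns 25:75.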

def sample(
    input_path: str = "outputs/inputs/test_image.png",  # Folder containing the video frames (read in sorted order)
    path_ref: str = None,
    ckpt: str = "checkpoints/svd.safetensors",
    config: str = None,
    num_frames: Optional[int] = None,
    num_steps: Optional[int] = None,
    version: str = "svd",
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 1,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = None,
    save_fps: int = 10,
    resize: Optional[bool] = False,
    # points = None
    s_w = None, 
    e_w = None, 
    s_h = None, 
    e_h = None,
    ps_h = None,
    ps_w = None,
):
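    """
    Edit the box s_w:e_w (rows) x s_h:e_h (columns) of a video with ReVideo.

    `input_path` is a folder of video frames (read in sorted file-name order) and `path_ref` is the
    reference image used as the first-frame condition. The motion condition comes from tracking the
    points (ps_h, ps_w) through the original frames with CoTracker. All coordinates are given in
    original-image pixels. If you run out of VRAM, try decreasing `decoding_t`.
    """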

    # flow
    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)
    guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device)
    visualizer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=2)
    visualizer_layer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4)

    torch.manual_seed(seed)

    all_img_paths = os.listdir(input_path)
    all_img_paths.sort()
    for i in range(len(all_img_paths)):
        all_img_paths[i] = os.path.join(input_path, all_img_paths[i])

    print(f'loaded {len(all_img_paths)} images.')
    if output_folder is not None:
        os.makedirs(output_folder, exist_ok=True)
    images = []
    for no, input_img_path in enumerate(all_img_paths):
        filepath, fullflname = os.path.split(input_img_path)
        filename, ext = os.path.splitext(fullflname)
        print(f'-sample {no+1}: {filename} ...')
        with Image.open(input_img_path) as image:
            if image.mode == "RGBA":
                image = image.convert("RGB")
            w_org, h_org = image.size
            image = resize_pil_image(image, max_resolution=1024*1024)
            # image = resize_pil_image(image, max_resolution=896*896)
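            # Resize factors. NB: this script calls image rows "w" and columns "h", so
            # scale_h is the horizontal (x) factor and scale_w the vertical (y) factor.
            scale_h = image.size[0]/w_org
            scale_w = image.size[1]/h_org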
            w, h = image.size

            if h % 64 != 0 or w % 64 != 0:
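                # Don't call map() here: `map` is a local variable later in sample(), which would raise UnboundLocalError.
                width, height = (x - x % 64 for x in (w, h))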
                image = image.resize((width, height))
                print(
                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                )

            image = ToTensor()(image)
            image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        images.append(image)
    images = images[:num_frames]
    images = torch.stack(images, dim=2)  # [1, 3, T, H, W]
    save_rgb_video((images.flip(1)+1)/2.,'outputs/org.mp4')

    img_ref = Image.open(path_ref).convert("RGB")  # handles RGBA, palette and grayscale inputs
    img_ref = img_ref.resize((images.shape[-1], images.shape[-2]))
    img_ref = ToTensor()(img_ref)
    img_ref = img_ref * 2.0 - 1.0
    img_ref = img_ref.unsqueeze(0).to(device)

    with torch.no_grad():
        vid_cotracker = ((images+1)/2.).permute(0,2,1,3,4)*255.
        grid = torch.zeros(1,len(ps_h),3)
        for i in range(len(ps_h)):
            grid[0,i,1] = ps_h[i]*scale_h
            grid[0,i,2] = ps_w[i]*scale_w
        grid = grid.to(device)
        tracks, _ = cotracker(vid_cotracker, queries=grid) # B T N 2,  B T N 1
    
    layer = torch.ones_like(images.permute(0,2,1,3,4))
    res_video = visualizer_layer.visualize(layer.cpu()*255., tracks=tracks, save_video=False)
    res_video = (
        (rearrange(res_video[0], "t c h w -> t h w c"))
        .numpy()
        .astype(np.uint8)
    )
    frame = cv2.cvtColor(res_video[-1], cv2.COLOR_RGB2BGR)
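    filename = os.path.basename(os.path.normpath(input_path))
    # cv2.imwrite does not create missing folders, so create vis_im/<name> first.
    os.makedirs(os.path.join('vis_im', filename), exist_ok=True)
    cv2.imwrite(os.path.join('vis_im', filename, 'layer.png'), frame)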

    b,c,n,w,h = images.shape  # NB: w is the frame height (rows), h the width (columns)
    s_w = int(s_w*scale_w)
    e_w = int(e_w*scale_w)
    s_h = int(s_h*scale_h)
    e_h = int(e_h*scale_h)
    model_config = config

    model, filter = load_model(
        model_config,
        ckpt,
        device,
        num_frames,
        num_steps,
    )


    select_point = tracks.flip(-1)  # CoTracker (x, y) -> (y, x)
    maps = []
    for i in range(num_frames-1):
        # Sparse flow for frame i+1: the (dy, dx) step of every tracked point since frame i,
        # written at the point's (clipped) pixel position in frame i+1.
        map = torch.zeros((1,2,img_ref.shape[-2],img_ref.shape[-1])).to(device)
        rows = select_point[:, i+1, :, 0].to(device).long()
        cols = select_point[:, i+1, :, 1].to(device).long()
        rows = torch.clip(rows, 0, w-1)
        cols = torch.clip(cols, 0, h-1)
        map[0,:,rows[0], cols[0]] = (select_point[0, i+1] - select_point[0, i]).permute(1,0)
        maps.append(map)

    maps = [torch.zeros_like(maps[0])]+maps  # frame 0 has no motion
    maps = torch.stack(maps, dim=2)  # [1, 2, n, w, h]
    with torch.no_grad():
        maps = maps.permute(0,2,1,3,4).reshape(b*n,2,w,h)
        maps = guassian_filter(maps).reshape(b,n,2,w,h)
    images[:,:,:,s_w:e_w,s_h:e_h] = -1

    save_flow_video(maps.permute(0,2,1,3,4), 'outputs/flow.mp4')
    save_rgb_video((images.flip(1)+1)/2.,'outputs/content.mp4')

    flow = maps.reshape(b*n,2,w,h)

    region = torch.zeros_like(images)[:,:1]
    region[:,:,:,s_w:e_w,s_h:e_h] = 1
    region = region.permute(0,2,1,3,4).reshape(b*n,1,w,h)

    value_dict = {}
    value_dict["video"] = images.to(dtype=torch.float16)
    value_dict["region"] = region.to(dtype=torch.float16)
    value_dict["motion_bucket_id"] = motion_bucket_id
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    # The reference image is the image condition; cond_frames additionally gets cond_aug-scaled noise.
    value_dict["cond_frames_without_noise"] = img_ref.to(dtype=torch.float16)
    value_dict["cond_frames"] = (img_ref + cond_aug * torch.randn_like(images[:,:,0])).to(dtype=torch.float16)

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )

            for k in ["crossattn", "concat"]:
                uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            randn = torch.randn(shape, device=device, dtype = torch.float16)

            additional_model_inputs = {}
            # Two rows: the guidance step stacks the unconditional and conditional inputs into one
            # batch. All zeros keeps the temporal mixing layers active.
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                2, num_frames
            ).to(device, dtype=torch.float16)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )

            c['mask'] = images.clone()
            uc['mask'] = images.clone()
            c['region'] = region.clone()
            uc['region'] = region.clone()
            c['ctrl_input'] = flow.clone()
            uc['ctrl_input'] = flow.clone()

            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
            model.en_and_decode_n_samples_a_time = decoding_t
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            filename = os.path.basename(os.path.normpath(input_path))
            res_video = visualizer.visualize(samples.unsqueeze(0).cpu()*255., tracks=tracks, save_video=False)
            visualizer.save_video(res_video, filename='result_cotracker_%s'%filename, writer=None, step=0)
            os.makedirs(os.path.join('vis_im', filename, 'im_w_track'), exist_ok=True)
            os.makedirs(os.path.join('vis_im', filename, 'im_wo_track'), exist_ok=True)
            # Generated frames without any overlay
            vid = (
                (rearrange(samples, "t c h w -> t h w c") * 255)
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            # Generated frames with the tracks drawn on top
            vid_w_track = (
                (rearrange(res_video[0], "t c h w -> t h w c"))
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            for idx, frame in enumerate(vid):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                name = '%04d.png'%idx
                cv2.imwrite(os.path.join('vis_im', filename, 'im_wo_track', name), frame)
            for idx, frame in enumerate(vid_w_track):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                name = '%04d.png'%idx
                cv2.imwrite(os.path.join('vis_im', filename, 'im_w_track', name), frame)
    
    print(f'Done! results saved in {output_folder}.')
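

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# The loop over `select_point` in sample() turns CoTracker tracks into a sparse motion
# condition. Frame 0 gets an all-zero map. For every later frame, the (dy, dx) step of
# each tracked point since the previous frame is written at the point's clipped pixel
# position, and the maps are then Gaussian-blurred. This numpy sketch covers only the
# sparse step and ASSUMES tracks of shape [T, N, 2] in (x, y) pixel order, as CoTracker
# returns them (the script flips them to (y, x) before indexing).
import numpy as np


def tracks_to_sparse_flow(tracks, height, width):
    """Return a [T, 2, height, width] array with (dy, dx) at each tracked position."""
    tracks_yx = np.asarray(tracks, dtype=np.float32)[..., ::-1]  # (x, y) -> (y, x)
    num_frames = tracks_yx.shape[0]
    flow = np.zeros((num_frames, 2, height, width), dtype=np.float32)
    for t in range(1, num_frames):
        # Truncate to integer pixels and clip to the frame, as the script does.
        rows = np.clip(tracks_yx[t, :, 0].astype(np.int64), 0, height - 1)
        cols = np.clip(tracks_yx[t, :, 1].astype(np.int64), 0, width - 1)
        delta = tracks_yx[t] - tracks_yx[t - 1]  # [N, 2] as (dy, dx)
        flow[t, 0, rows, cols] = delta[:, 0]
        flow[t, 1, rows, cols] = delta[:, 1]
    return flow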


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=23, help="random seed (torch.manual_seed)")
    parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--input", type=str, default=None, help="folder containing the input video frames")
    parser.add_argument("--path_ref", type=str, default=None, help="reference image path")
    parser.add_argument("--savedir", type=str, default=None, help="output folder (created if missing)")
    parser.add_argument("--savefps", type=int, default=10, help="fps of saved videos (currently unused)")
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt (currently unused)",)
    parser.add_argument("--ddim_steps", type=int, default=50, help="number of sampling steps (passed to load_model)",)
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for DDIM sampling (currently unused)",)
    parser.add_argument("--frames", type=int, default=-1, help="number of frames to edit; set it explicitly, the default -1 is not handled")
    parser.add_argument("--fps", type=int, default=6, help="fps conditioning value (fps_id)")
    parser.add_argument("--motion", type=int, default=127, help="motion_bucket_id conditioning; controls the motion magnitude")
    parser.add_argument("--cond_aug", type=float, default=0.02, help="strength of the noise added to the reference image condition")
    parser.add_argument("--decoding_t", type=int, default=1, help="number of frames decoded at a time; lower it to save VRAM")
    parser.add_argument("--resize", action='store_true', default=False, help="resize all inputs to the default resolution (currently unused)")
    parser.add_argument("--s_w", type=int, default=None, help="edit box: first row (y), original-image pixels")
    parser.add_argument("--e_w", type=int, default=None, help="edit box: end row (y, exclusive)")
    parser.add_argument("--s_h", type=int, default=None, help="edit box: first column (x)")
    parser.add_argument("--e_h", type=int, default=None, help="edit box: end column (x, exclusive)")
    parser.add_argument("--ps_h", metavar="N", type=int, nargs="+", default=None, help="x coordinates of the points to track")
    parser.add_argument("--ps_w", metavar="N", type=int, nargs="+", default=None, help="y coordinates of the points to track")
    return parser
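

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# sample() densifies the sparse flow with
# quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)), a frozen blur
# applied to the dy and dx channels. utils/tools.py is not shown here, so this numpy
# sketch ASSUMES get_gaussian_kernel builds a separable, sum-to-one 2-D Gaussian. It
# only constructs the kernel, not the convolution.
import numpy as np


def gaussian_kernel_2d(kernel_size=51, sigma=10.0):
    """Return a [kernel_size, kernel_size] Gaussian kernel centred on the middle pixel."""
    offsets = np.arange(kernel_size, dtype=np.float64) - (kernel_size - 1) / 2.0
    g = np.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
    kernel = np.outer(g, g)  # separable: product of two 1-D Gaussians
    return kernel / kernel.sum()  # normalise so the weights sum to 1

# Design note: with a sum-to-one kernel, a single tracked pixel's displacement is spread
# over its neighbourhood, and its peak drops to roughly 1 / (2 * pi * sigma**2) of the value.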


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print("@SVD Inference: %s"%now)
    #Fire(sample)
    parser = get_parser()
    args = parser.parse_args()
    sample(input_path=args.input, path_ref=args.path_ref, ckpt=args.ckpt, config=args.config, num_frames=args.frames, num_steps=args.ddim_steps, \
        fps_id=args.fps, motion_bucket_id=args.motion, cond_aug=args.cond_aug, seed=args.seed, \
        decoding_t=args.decoding_t, output_folder=args.savedir, save_fps=args.savefps, resize=args.resize, s_w=args.s_w, e_w=args.e_w, s_h=args.s_h, e_h=args.e_h, ps_w=args.ps_w, ps_h=args.ps_h)


================================================
FILE: main/inference/sample_multi_region.py
================================================
import datetime, time
import os, sys, argparse
import math
from glob import glob
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from PIL import Image
from torchvision.transforms import ToTensor

sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from utils.save_video import save_flow_video, save_rgb_video
import torch.nn as nn
from utils.visualizer import Visualizer
from utils.tools import resize_pil_image, quick_freeze, get_gaussian_kernel, get_batch, get_unique_embedder_keys_from_conditioner, load_model

if not os.path.exists('ckpt'):
    os.makedirs('ckpt')
if not os.path.exists('ckpt/model.ckpt'):
    torch.hub.download_url_to_file(
        'https://huggingface.co/Adapter/ReVideo/resolve/main/model.ckpt',
        'ckpt/model.ckpt')
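

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# For each region, sample() below moves one handle point along a straight line from
# (ps_h, ps_w) to (pe_h, pe_w) in num_frames - 2 equal steps. It adds a per-step "swing"
# taken cyclically from x_bias / y_bias and stores the integer frame-to-frame
# displacement at each new position. A zero map for frame 0 is prepended afterwards.
# This pure-Python sketch reproduces only the position schedule of that map-building
# loop. Coordinates are (x, y) in resized-image pixels.


def swing_trajectory(p_start, p_end, num_frames, x_bias, y_bias):
    """Return num_frames - 1 integer (x, y) positions, one per flow map."""
    steps = num_frames - 2
    inter_x = (p_end[0] - p_start[0]) / steps
    inter_y = (p_end[1] - p_start[1]) / steps
    positions = []
    for i in range(num_frames - 1):
        x = int(p_start[0] + i * inter_x + x_bias[i % len(x_bias)])
        y = int(p_start[1] + i * inter_y + y_bias[i % len(y_bias)])
        positions.append((x, y))
    return positions

# Example: swing_trajectory((0, 0), (30, 0), 5, [0, 2], [0]) alternates a 2-pixel
# horizontal swing on top of 10-pixel steps: (0, 0), (12, 0), (20, 0), (32, 0).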
        
def sample(
    input_path: str = "outputs/inputs/test_image.png",  # Folder containing the video frames (read in sorted order)
    path_ref: str = None,
    ckpt: str = "checkpoints/svd.safetensors",
    config: str = None,
    num_frames: Optional[int] = None,
    num_steps: Optional[int] = None,
    version: str = "svd",
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 1,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = None,
    save_fps: int = 10,
    resize: Optional[bool] = False,
    # points = None
    s_w = None, 
    e_w = None, 
    s_h = None, 
    e_h = None,
    ps_w = None, 
    pe_w = None, 
    ps_h = None, 
    pe_h = None,
    x_bias_all = None,
    y_bias_all = None,
):
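    """
    Edit several boxes of a video with ReVideo, each driven by a user-drawn trajectory.

    `input_path` is a folder of video frames (read in sorted file-name order) and `path_ref` is the
    reference image used as the first-frame condition. Region k is the box s_w[k]:e_w[k] (rows) x
    s_h[k]:e_h[k] (columns). Its handle point moves in a straight line from (ps_h[k], ps_w[k]) to
    (pe_h[k], pe_w[k]), given as (x, y), with per-frame offsets from x_bias_all[k] / y_bias_all[k].
    All coordinates are in original-image pixels. If you run out of VRAM, try decreasing `decoding_t`.
    """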

    # flow
    guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device)
    visualizer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=2)
    visualizer_layer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4)

    torch.manual_seed(seed)

    all_img_paths = os.listdir(input_path)
    all_img_paths.sort()
    for i in range(len(all_img_paths)):
        all_img_paths[i] = os.path.join(input_path, all_img_paths[i])

    print(f'loaded {len(all_img_paths)} images.')
    if output_folder is not None:
        os.makedirs(output_folder, exist_ok=True)
    images = []
    for no, input_img_path in enumerate(all_img_paths):
        filepath, fullflname = os.path.split(input_img_path)
        filename, ext = os.path.splitext(fullflname)
        print(f'-sample {no+1}: {filename} ...')
        with Image.open(input_img_path) as image:
            if image.mode == "RGBA":
                image = image.convert("RGB")
            w_org, h_org = image.size
            image = resize_pil_image(image, max_resolution=1024*1024)
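            # Resize factors. NB: this script calls image rows "w" and columns "h", so
            # scale_h is the horizontal (x) factor and scale_w the vertical (y) factor.
            scale_h = image.size[0]/w_org
            scale_w = image.size[1]/h_org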
            w, h = image.size

            if h % 64 != 0 or w % 64 != 0:
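                # Don't call map() here: `map` is a local variable later in sample(), which would raise UnboundLocalError.
                width, height = (x - x % 64 for x in (w, h))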
                image = image.resize((width, height))
                print(
                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                )

            image = ToTensor()(image)
            image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        images.append(image)
    images = images[:num_frames]
    images = torch.stack(images, dim=2)
    save_rgb_video((images.flip(1)+1)/2.,'outputs/org.mp4')

    img_ref = Image.open(path_ref).convert("RGB")  # handles RGBA, palette and grayscale inputs
    img_ref = img_ref.resize((images.shape[-1], images.shape[-2]))
    img_ref = ToTensor()(img_ref)
    img_ref = img_ref * 2.0 - 1.0
    img_ref = img_ref.unsqueeze(0).to(device)

    b,c,n,w,h = images.shape  # NB: w is the frame height (rows), h the width (columns)
    map = torch.zeros((1,2,num_frames-1,img_ref.shape[-2],img_ref.shape[-1])).to(device)
    region = torch.zeros_like(images)[:,:1]
    tracks = []
    assert len(x_bias_all)==len(y_bias_all) and len(x_bias_all)==len(s_w), \
        'Pass one --x_bias_all and one --y_bias_all list per region (as many as --s_w values).'
    for k in range(len(s_w)):
        s_w[k] = int(s_w[k]*scale_w)
        e_w[k] = int(e_w[k]*scale_w)
        s_h[k] = int(s_h[k]*scale_h)
        e_h[k] = int(e_h[k]*scale_h)
        images[:,:,:,s_w[k]:e_w[k],s_h[k]:e_h[k]] = -1
        region[:,:,:,s_w[k]:e_w[k],s_h[k]:e_h[k]] = 1

        p_start = [ps_h[k]*scale_h, ps_w[k]*scale_w]
        p_end = [pe_h[k]*scale_h, pe_w[k]*scale_w]

        x_dist = p_end[0]-p_start[0]
        y_dist = p_end[1]-p_start[1]
        inter_x = x_dist/(num_frames-2)
        inter_y = y_dist/(num_frames-2)
        x_bias = x_bias_all[k]
        y_bias = y_bias_all[k]
        for i in range(num_frames-1):
            x_cur = int(p_start[0]+i*inter_x + x_bias[i%len(x_bias)])
            y_cur = int(p_start[1]+i*inter_y + y_bias[i%len(y_bias)])
            if i == 0:
                x_per = p_start[0]
                y_per = p_start[1]
            else:
                x_per = int(p_start[0]+(i-1)*inter_x + x_bias[(i-1)%len(x_bias)])
                y_per = int(p_start[1]+(i-1)*inter_y+ y_bias[(i-1)%len(y_bias)])
            map[:,1,i,y_cur,x_cur] = x_cur - x_per
            map[:,0,i,y_cur,x_cur] = y_cur - y_per

        track = torch.zeros(b,num_frames,1,2)
        for i in range(num_frames):
            if i == 0:
                x_cur = p_start[0]
                y_cur = p_start[1]
            else:
                x_cur = int(p_start[0]+i*inter_x + x_bias[i%len(x_bias)])
                y_cur = int(p_start[1]+i*inter_y + y_bias[i%len(y_bias)])
            track[0,i,0,0]=x_cur
            track[0,i,0,1]=y_cur
        tracks.append(track)
    tracks = torch.cat(tracks, dim=2)

    layer = torch.ones_like(images.permute(0,2,1,3,4))
    res_video = visualizer_layer.visualize(layer.cpu()*255., tracks=tracks, save_video=False)
    res_video = (
        (rearrange(res_video[0], "t c h w -> t h w c"))
        .numpy()
        .astype(np.uint8)
    )
    frame = cv2.cvtColor(res_video[-1], cv2.COLOR_RGB2BGR)
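    filename = os.path.basename(os.path.normpath(input_path))
    # cv2.imwrite does not create missing folders, so create vis_im/<name> first.
    os.makedirs(os.path.join('vis_im', filename), exist_ok=True)
    cv2.imwrite(os.path.join('vis_im', filename, 'layer.png'), frame)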

    pad = torch.zeros_like(map[:,:,:1,:,:])
    map = torch.cat([pad, map], dim=2)
    with torch.no_grad():
        map = map.permute(0,2,1,3,4).reshape(b*n,2,w,h)
        map = guassian_filter(map).reshape(b,n,2,w,h)

    save_flow_video(map.permute(0,2,1,3,4), 'outputs/flow.mp4')
    save_rgb_video((images.flip(1)+1)/2.,'outputs/content.mp4')

    flow = map.reshape(b*n,2,w,h)

    model_config = config

    model, filter = load_model(
        model_config,
        ckpt,
        device,
        num_frames,
        num_steps,
    )
    
    region = region.permute(0,2,1,3,4).reshape(b*n,1,w,h)

    value_dict = {}
    value_dict["video"] = images.to(dtype=torch.float16)
    value_dict["region"] = region.to(dtype=torch.float16)
    value_dict["motion_bucket_id"] = motion_bucket_id
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    value_dict["cond_frames_without_noise"] = img_ref.to(dtype=torch.float16) #images[:,:,0]
    value_dict["cond_frames"] = img_ref.to(dtype=torch.float16) #images[:,:,0] # + cond_aug * torch.randn_like(images[:,:,0])

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )

            for k in ["crossattn", "concat"]:
                uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            randn = torch.randn(shape, device=device, dtype=torch.float16)

            additional_model_inputs = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                2, num_frames
            ).to(device, dtype=torch.float16)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )

            c['mask'] = images.clone()
            uc['mask'] = images.clone()
            c['region'] = region.clone()
            uc['region'] = region.clone()
            c['ctrl_input'] = flow.clone()
            uc['ctrl_input'] = flow.clone()

            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
            model.en_and_decode_n_samples_a_time = decoding_t
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            filename = os.path.basename(os.path.normpath(input_path))
            res_video = visualizer.visualize(samples.unsqueeze(0).cpu()*255., tracks=tracks, save_video=False)
            visualizer.save_video(res_video, filename='result_cotracker_%s'%filename, writer=None, step=0)

            os.makedirs(os.path.join('vis_im', filename, 'im_w_track'), exist_ok=True)
            os.makedirs(os.path.join('vis_im', filename, 'im_wo_track'), exist_ok=True)
            # Generated frames without any overlay
            vid = (
                (rearrange(samples, "t c h w -> t h w c") * 255)
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            # Generated frames with the trajectories drawn on top
            vid_w_track = (
                (rearrange(res_video[0], "t c h w -> t h w c"))
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            for idx, frame in enumerate(vid):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                name = '%04d.png'%idx
                cv2.imwrite(os.path.join('vis_im', filename, 'im_wo_track', name), frame)
            for idx, frame in enumerate(vid_w_track):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                name = '%04d.png'%idx
                cv2.imwrite(os.path.join('vis_im', filename, 'im_w_track', name), frame)
    
    print(f'Done! results saved in {output_folder}.')
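

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# With several regions, sample() blacks out each box in the content video (value -1 in
# the [-1, 1] range) and marks the union of all boxes with 1 in the single-channel
# `region` map. This numpy sketch handles one [C, H, W] frame. Each box is
# (row_start, row_end, col_start, col_end), i.e. the script's (s_w, e_w, s_h, e_h)
# after rescaling.
import numpy as np


def mask_regions(frame, boxes):
    """Return (masked_frame, region) for a [C, H, W] frame scaled to [-1, 1]."""
    masked = frame.copy()
    region = np.zeros((1,) + frame.shape[1:], dtype=frame.dtype)
    for r0, r1, c0, c1 in boxes:
        masked[:, r0:r1, c0:c1] = -1.0  # drop the original content inside the box
        region[:, r0:r1, c0:c1] = 1.0   # 1 marks the area to be regenerated
    return masked, region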


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=23, help="random seed (torch.manual_seed)")
    parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--input", type=str, default=None, help="folder containing the input video frames")
    parser.add_argument("--path_ref", type=str, default=None, help="reference image path")
    parser.add_argument("--savedir", type=str, default=None, help="output folder (created if missing)")
    parser.add_argument("--savefps", type=int, default=10, help="fps of saved videos (currently unused)")
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt (currently unused)",)
    parser.add_argument("--ddim_steps", type=int, default=50, help="number of sampling steps (passed to load_model)",)
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for DDIM sampling (currently unused)",)
    parser.add_argument("--frames", type=int, default=-1, help="number of frames to edit; set it explicitly, the default -1 is not handled")
    parser.add_argument("--fps", type=int, default=6, help="fps conditioning value (fps_id)")
    parser.add_argument("--motion", type=int, default=127, help="motion_bucket_id conditioning; controls the motion magnitude")
    parser.add_argument("--cond_aug", type=float, default=0.02, help="cond_aug conditioning value (this script adds no noise to the reference image)")
    parser.add_argument("--decoding_t", type=int, default=1, help="number of frames decoded at a time; lower it to save VRAM")
    parser.add_argument("--resize", action='store_true', default=False, help="resize all inputs to the default resolution (currently unused)")
    parser.add_argument("--s_w", metavar="N", type=int, nargs="+", default=None, help="per-region first row (y) of the edit box, original-image pixels")
    parser.add_argument("--e_w", metavar="N", type=int, nargs="+", default=None, help="per-region end row (y, exclusive) of the edit box")
    parser.add_argument("--s_h", metavar="N", type=int, nargs="+", default=None, help="per-region first column (x) of the edit box")
    parser.add_argument("--e_h", metavar="N", type=int, nargs="+", default=None, help="per-region end column (x, exclusive) of the edit box")
    parser.add_argument("--ps_w", metavar="N", type=int, nargs="+", default=None, help="per-region trajectory start, y coordinate")
    parser.add_argument("--pe_w", metavar="N", type=int, nargs="+", default=None, help="per-region trajectory end, y coordinate")
    parser.add_argument("--ps_h", metavar="N", type=int, nargs="+", default=None, help="per-region trajectory start, x coordinate")
    parser.add_argument("--pe_h", metavar="N", type=int, nargs="+", default=None, help="per-region trajectory end, x coordinate")
    parser.add_argument("--x_bias_all", nargs="+", action="append", type=int, help="horizontal swing: per-frame x offsets (pixels), applied cyclically; repeat the flag once per region")
    parser.add_argument("--y_bias_all", nargs="+", action="append", type=int, help="vertical swing: per-frame y offsets (pixels), applied cyclically; repeat the flag once per region")
    return parser
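

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# --x_bias_all / --y_bias_all combine nargs="+" with action="append", so each occurrence
# of the flag contributes one list. Passing the flag once per region yields a list of
# per-region bias lists, which is what the assertion in sample() compares against
# len(s_w). This demo uses a throwaway parser that reproduces only two of the flags.
import argparse


def demo_bias_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--s_w", metavar="N", type=int, nargs="+", default=None)
    parser.add_argument("--x_bias_all", nargs="+", action="append", type=int)
    return parser

# Usage: "--s_w 10 20 --x_bias_all 0 3 --x_bias_all -2" gives s_w == [10, 20] and
# x_bias_all == [[0, 3], [-2]]. Negative offsets parse as values because no option looks
# like a negative number. Omitting the flag leaves None, and len(None) raises TypeError.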


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print("@SVD Inference: %s"%now)
    #Fire(sample)
    parser = get_parser()
    args = parser.parse_args()
    sample(input_path=args.input, path_ref=args.path_ref, ckpt=args.ckpt, config=args.config, num_frames=args.frames, num_steps=args.ddim_steps, \
        fps_id=args.fps, motion_bucket_id=args.motion, cond_aug=args.cond_aug, seed=args.seed, \
        decoding_t=args.decoding_t, output_folder=args.savedir, save_fps=args.savefps, resize=args.resize, \
        s_w=args.s_w, e_w=args.e_w, s_h=args.s_h, e_h=args.e_h, ps_w=args.ps_w, pe_w=args.pe_w, ps_h=args.ps_h, \
        pe_h=args.pe_h, x_bias_all=args.x_bias_all, y_bias_all=args.y_bias_all)


================================================
FILE: main/inference/sample_single_region.py
================================================
import datetime, time
import os, sys, argparse
import math
from glob import glob
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor

sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from sgm.util import default, instantiate_from_config
from utils.save_video import save_flow_video, save_rgb_video
import torch.nn as nn
from utils.visualizer import Visualizer
from utils.tools import resize_pil_image, quick_freeze, get_gaussian_kernel, get_batch, get_unique_embedder_keys_from_conditioner, load_model

if not os.path.exists('ckpt'):
    os.makedirs('ckpt')
if not os.path.exists('ckpt/model.ckpt'):
    torch.hub.download_url_to_file(
        'https://huggingface.co/Adapter/ReVideo/resolve/main/model.ckpt',
        'ckpt/model.ckpt')
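

# --- Illustrative sketch (hypothetical helper, not part of ReVideo) ---------------
# In the sampling scripts, every "crossattn" / "concat" conditioning tensor is expanded
# from one entry per clip to one entry per frame with
#   repeat(x, "b ... -> b t ...", t=T) followed by rearrange(x, "b t ... -> (b t) ...").
# Because b is the outer axis of (b t), this is equivalent to repeating each batch
# element T times consecutively. This numpy sketch shows the equivalent operation.
import numpy as np


def expand_per_frame(x, num_frames):
    """[B, ...] -> [B * num_frames, ...], each element repeated num_frames times in a row."""
    return np.repeat(x, num_frames, axis=0)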

def sample(
    input_path: str = "outputs/inputs/test_image.png",  # Folder containing the video frames (read in sorted order)
    path_ref: str = None,
    ckpt: str = "checkpoints/svd.safetensors",
    config: str = None,
    num_frames: Optional[int] = None,
    num_steps: Optional[int] = None,
    version: str = "svd",
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 1,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = None,
    save_fps: int = 10,
    resize: Optional[bool] = False,
    # points = None
    s_w = None, 
    e_w = None, 
    s_h = None, 
    e_h = None,
    ps_w = None, 
    pe_w = None, 
    ps_h = None, 
    pe_h = None,
    x_bias = None,
    y_bias = None,
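):
    """
    Edit the video whose frames are stored in folder `input_path` (loaded in sorted filename
    order). The box [s_w:e_w, s_h:e_h] is masked out and regenerated, conditioned on the
    reference image `path_ref` and on sparse motion along the trajectories given by
    (ps_*, pe_*) plus the per-frame offsets x_bias / y_bias. If you run out of VRAM, try
    decreasing `decoding_t`.
    """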
    guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=51, sigma=10, channels=2)).to(device)
    visualizer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4)
    visualizer_layer = Visualizer(tracks_leave_trace=-1, show_first_frame=0, mode='cool', linewidth=4)

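    torch.manual_seed(seed)

    # Default to zero per-frame jitter when no bias values are given
    # (the trajectory loops below index into x_bias / y_bias).
    x_bias = x_bias if x_bias is not None else [0]
    y_bias = y_bias if y_bias is not None else [0]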

    all_img_paths = os.listdir(input_path)
    all_img_paths.sort()
    for i in range(len(all_img_paths)):
        all_img_paths[i] = os.path.join(input_path, all_img_paths[i])

    print(f'loaded {len(all_img_paths)} images.')
    os.makedirs(output_folder, exist_ok=True)
    images = []
    for no, input_img_path in enumerate(all_img_paths):
        filepath, fullflname = os.path.split(input_img_path)
        filename, ext = os.path.splitext(fullflname)
        print(f'-sample {no+1}: {filename} ...')
        with Image.open(input_img_path) as image:
            if image.mode == "RGBA":
                image = image.convert("RGB")
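            w_org, h_org = image.size
            image = resize_pil_image(image, max_resolution=1024*1024)
            w, h = image.size

            if h % 64 != 0 or w % 64 != 0:
                # Plain arithmetic instead of map(): `map` is reassigned later in this
                # function, so calling the builtin here would raise UnboundLocalError.
                width, height = w - w % 64, h - h % 64
                image = image.resize((width, height))
                print(
                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                )

            # Scale from original to final (64-aligned) size. In this script the `*_h`
            # coordinates are columns (x) and the `*_w` coordinates are rows (y), so
            # scale_h is a width ratio and scale_w a height ratio.
            scale_h = image.size[0]/w_org
            scale_w = image.size[1]/h_org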

            image = ToTensor()(image)
            image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        images.append(image)
    images = images[:num_frames]
    images = torch.stack(images, dim=2)
    save_rgb_video((images.flip(1)+1)/2.,'outputs/org.mp4')

    img_ref = Image.open(path_ref)
    if img_ref.mode == "RGBA":
        img_ref = img_ref.convert("RGB")
    img_ref = img_ref.resize((images.shape[-1], images.shape[-2]))
    img_ref = ToTensor()(img_ref)
    img_ref = img_ref * 2.0 - 1.0
    img_ref = img_ref.unsqueeze(0).to(device)

    b,c,n,w,h = images.shape  # w: rows (height), h: columns (width)
    s_w = int(s_w*scale_w)
    e_w = int(e_w*scale_w)
    s_h = int(s_h*scale_h)
    e_h = int(e_h*scale_h)
    p_start = [[ps_h_i*scale_h, ps_w_i*scale_w] for ps_h_i, ps_w_i in zip(ps_h, ps_w)]
    p_end = [[pe_h_i*scale_h, pe_w_i*scale_w] for pe_h_i, pe_w_i in zip(pe_h, pe_w)]
    tracks = torch.zeros(b,num_frames,len(p_start),2)
    for k in range(len(p_start)):
        x_dist = p_end[k][0]-p_start[k][0]
        y_dist = p_end[k][1]-p_start[k][1]
        inter_x = x_dist/(num_frames-2)
        inter_y = y_dist/(num_frames-2)
        for i in range(num_frames):
            if i == 0:
                x_cur = p_start[k][0]
                y_cur = p_start[k][1]
            else:
                x_cur = int(p_start[k][0]+i*inter_x + x_bias[i%len(x_bias)])
                y_cur = int(p_start[k][1]+i*inter_y + y_bias[i%len(y_bias)])
            tracks[0,i,k,0]=x_cur
            tracks[0,i,k,1]=y_cur
    layer = torch.ones_like(images.permute(0,2,1,3,4))
    res_video = visualizer_layer.visualize(layer.cpu()*255., tracks=tracks, save_video=False)
    res_video = (
        (rearrange(res_video[0], "t c h w -> t h w c"))
        .numpy()
        .astype(np.uint8)
    )
    frame = cv2.cvtColor(res_video[-1], cv2.COLOR_RGB2BGR)
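    filename = os.path.basename(os.path.normpath(input_path))
    # Create the output folder first; cv2.imwrite does not create missing directories
    os.makedirs(os.path.join('vis_im', filename), exist_ok=True)
    cv2.imwrite(os.path.join('vis_im', filename, 'layer.png'), frame)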

    model_config = config
    model, filter = load_model(
        model_config,
        ckpt,
        device,
        num_frames,
        num_steps,
    )

    map = torch.zeros((1,2,num_frames-1,img_ref.shape[-2],img_ref.shape[-1])).to(device)
    for k in range(len(p_start)):
        x_dist = p_end[k][0]-p_start[k][0]
        y_dist = p_end[k][1]-p_start[k][1]
        inter_x = x_dist/(num_frames-2)
        inter_y = y_dist/(num_frames-2)
        for i in range(num_frames-1):
            x_cur = int(p_start[k][0]+i*inter_x + x_bias[i%len(x_bias)])
            y_cur = int(p_start[k][1]+i*inter_y + y_bias[i%len(y_bias)])
            if i == 0:
                x_per = p_start[k][0]
                y_per = p_start[k][1]
            else:
                x_per = int(p_start[k][0]+(i-1)*inter_x + x_bias[(i-1)%len(x_bias)])
                y_per = int(p_start[k][1]+(i-1)*inter_y+ y_bias[(i-1)%len(y_bias)])
            map[:,1,i,y_cur,x_cur] = x_cur - x_per
            map[:,0,i,y_cur,x_cur] = y_cur - y_per

    pad = torch.zeros_like(map[:,:,:1,:,:])
    map = torch.cat([pad, map], dim=2)
    with torch.no_grad():
        map = map.permute(0,2,1,3,4).reshape(b*n,2,w,h)
        map = guassian_filter(map).reshape(b,n,2,w,h)
    images[:,:,:,s_w:e_w,s_h:e_h] = -1

    save_flow_video(map.permute(0,2,1,3,4), 'outputs/flow.mp4')
    save_rgb_video((images.flip(1)+1)/2.,'outputs/content.mp4')

    flow = map.reshape(b*n,2,w,h)
    region = torch.zeros_like(images)[:,:1]
    region[:,:,:,s_w:e_w,s_h:e_h] = 1
    region = region.permute(0,2,1,3,4).reshape(b*n,1,w,h)

    value_dict = {}
    value_dict["video"] = images.to(dtype=torch.float16)
    value_dict["region"] = region.to(dtype=torch.float16)
    value_dict["motion_bucket_id"] = motion_bucket_id
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    value_dict["cond_frames_without_noise"] = img_ref.to(dtype=torch.float16)
    value_dict["cond_frames"] = (img_ref + cond_aug * torch.randn_like(images[:,:,0])).to(dtype=torch.float16)

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )

            for k in ["crossattn", "concat"]:
                uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            randn = torch.randn(shape, device=device, dtype=torch.float16)

            additional_model_inputs = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                2, num_frames
            ).to(device, dtype=torch.float16)
            #additional_model_inputs["image_only_indicator"][:,0] = 1
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )
            c['mask'] = images.clone()
            uc['mask'] = images.clone()
            c['region'] = region.clone()
            uc['region'] = region.clone()
            c['ctrl_input'] = flow.clone()
            uc['ctrl_input'] = flow.clone()

            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
            # model.first_stage_model.to(dtype=torch.float32)
            model.en_and_decode_n_samples_a_time = decoding_t
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            filename = os.path.basename(os.path.normpath(input_path))
            res_video = visualizer.visualize(samples.unsqueeze(0).cpu()*255., tracks=tracks, save_video=False)
            visualizer.save_video(res_video, filename='result_cotracker_%s'%filename, writer=None, step=0)

            os.makedirs(os.path.join('vis_im', filename, 'im_w_track'), exist_ok=True)
            os.makedirs(os.path.join('vis_im', filename, 'im_wo_track'), exist_ok=True)
            # Raw generated frames (no trajectories drawn) -> im_wo_track
            vid_wo_track = (
                (rearrange(samples, "t c h w -> t h w c") * 255)
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            # Frames with the trajectories overlaid by the visualizer -> im_w_track
            vid_w_track = (
                (rearrange(res_video[0], "t c h w -> t h w c"))
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            for idx, frame in enumerate(vid_wo_track):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                name = '%04d.png'%idx
                cv2.imwrite(os.path.join('vis_im', filename, 'im_wo_track', name), frame)
            for idx, frame in enumerate(vid_w_track):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                name = '%04d.png'%idx
                cv2.imwrite(os.path.join('vis_im', filename, 'im_w_track', name), frame)
    
    print(f'Done! Frames saved in vis_im/{filename}.')

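
# Illustrative sketch of the trajectory construction in `sample`: each point
# moves linearly from its start to its end position with step
# (end - start) / (num_frames - 2), plus a per-frame jitter taken cyclically
# from x_bias / y_bias; frame 0 keeps the exact start position. The function
# name `linear_track` is hypothetical and only mirrors that loop.
def linear_track(start, end, num_frames, x_bias=(0,), y_bias=(0,)):
    """start/end are (x, y) pixels; returns one (x, y) position per frame."""
    step_x = (end[0] - start[0]) / (num_frames - 2)
    step_y = (end[1] - start[1]) / (num_frames - 2)
    track = [(start[0], start[1])]
    for i in range(1, num_frames):
        x = int(start[0] + i * step_x + x_bias[i % len(x_bias)])
        y = int(start[1] + i * step_y + y_bias[i % len(y_bias)])
        track.append((x, y))
    return track

# The point reaches `end` at frame num_frames - 2; the last frame continues
# one more step:
# linear_track((0, 0), (30, 15), 5) -> [(0, 0), (10, 5), (20, 10), (30, 15), (40, 20)]
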

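# Illustrative sketch of how `sample` turns one trajectory into sparse flow:
# for each frame transition it writes the displacement (dy, dx) between the
# previous and the current position at the current pixel (channel 0 = dy,
# channel 1 = dx in the original `map` tensor), and a zero frame is
# prepended so there is one flow entry per video frame. The name
# `sparse_flow_from_endpoints` and the dict-per-frame representation are
# hypothetical; the script stores the same values in a dense
# (1, 2, num_frames, H, W) tensor and then blurs it with a 51x51 Gaussian
# kernel (sigma=10), which spreads each sparse value over its neighborhood.
def sparse_flow_from_endpoints(start, end, num_frames, x_bias=(0,), y_bias=(0,)):
    """start/end are (x, y) pixels; returns a list of {(y, x): (dy, dx)} dicts."""
    step_x = (end[0] - start[0]) / (num_frames - 2)
    step_y = (end[1] - start[1]) / (num_frames - 2)

    def position(i):
        # Same rounding as the script: int() of the biased linear position.
        return (int(start[0] + i * step_x + x_bias[i % len(x_bias)]),
                int(start[1] + i * step_y + y_bias[i % len(y_bias)]))

    flow = [{}]  # frame 0: zero padding, no motion
    for i in range(num_frames - 1):
        x_cur, y_cur = position(i)
        x_prev, y_prev = start if i == 0 else position(i - 1)
        flow.append({(y_cur, x_cur): (y_cur - y_prev, x_cur - x_prev)})
    return flow

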
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=23, help="random seed (passed to torch.manual_seed)")
    parser.add_argument("--ckpt", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--input", type=str, default=None, help="folder containing the input video frames")
    parser.add_argument("--path_ref", type=str, default=None, help="reference image path")
    parser.add_argument("--savedir", type=str, default=None, help="results saving path")
    parser.add_argument("--savefps", type=int, default=10, help="fps of the saved video (unused in this script)")
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt (unused in this script)",)
    parser.add_argument("--ddim_steps", type=int, default=50, help="number of sampling steps",)
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (unused in this script)",)
    parser.add_argument("--frames", type=int, default=-1, help="number of frames to generate (must be set to a positive value)")
    parser.add_argument("--fps", type=int, default=6, help="fps conditioning value (fps_id)")
    parser.add_argument("--motion", type=int, default=127, help="motion bucket id; controls the motion magnitude")
    parser.add_argument("--cond_aug", type=float, default=0.02, help="noise level added to the reference image")
    parser.add_argument("--decoding_t", type=int, default=1, help="number of frames decoded at a time; lower it to save VRAM")
    parser.add_argument("--resize", action='store_true', default=False, help="resize all input to default resolution (unused in this script)")
    parser.add_argument("--s_w", type=int, default=None, help="first row (y) of the edit region, in original-image pixels")
    parser.add_argument("--e_w", type=int, default=None, help="end row (y, exclusive) of the edit region")
    parser.add_argument("--s_h", type=int, default=None, help="first column (x) of the edit region")
    parser.add_argument("--e_h", type=int, default=None, help="end column (x, exclusive) of the edit region")
    parser.add_argument("--ps_w", metavar="N", type=int, nargs="+", default=None, help="start row (y) of each trajectory")
    parser.add_argument("--pe_w", metavar="N", type=int, nargs="+", default=None, help="end row (y) of each trajectory")
    parser.add_argument("--ps_h", metavar="N", type=int, nargs="+", default=None, help="start column (x) of each trajectory")
    parser.add_argument("--pe_h", metavar="N", type=int, nargs="+", default=None, help="end column (x) of each trajectory")
    parser.add_argument("--x_bias", metavar="N", type=int, nargs="+", default=None, help="per-frame x offsets added to every trajectory (cycled)")
    parser.add_argument("--y_bias", metavar="N", type=int, nargs="+", default=None, help="per-frame y offsets added to every trajectory (cycled)")
    return parser

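# Illustrative sketch of how the region and trajectory flags are parsed:
# the box bounds are single ints, while the trajectory endpoints and the
# per-frame biases use nargs="+", so several trajectories (or several bias
# values) can follow one flag. `build_region_parser` and
# `parse_region_args` are hypothetical helpers covering only a subset of
# `get_parser`'s options; substituting [0] for a missing bias list is part
# of this sketch.
import argparse


def build_region_parser():
    parser = argparse.ArgumentParser()
    for name in ("s_w", "e_w", "s_h", "e_h"):
        parser.add_argument(f"--{name}", type=int, default=None)
    for name in ("ps_w", "pe_w", "ps_h", "pe_h", "x_bias", "y_bias"):
        parser.add_argument(f"--{name}", metavar="N", type=int, nargs="+", default=None)
    return parser


def parse_region_args(argv):
    args = build_region_parser().parse_args(argv)
    # No jitter when no bias values are given.
    args.x_bias = args.x_bias or [0]
    args.y_bias = args.y_bias or [0]
    return args

# Example: two trajectories starting at columns 100 and 200.
# parse_region_args(["--ps_h", "100", "200", "--pe_h", "150", "250"]).ps_h -> [100, 200]
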

if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print("@SVD Inference: %s"%now)
    #Fire(sample)
    parser = get_parser()
    args = parser.parse_args()
    sample(input_path=args.input, path_ref=args.path_ref, ckpt=args.ckpt, config=args.config, num_frames=args.frames, num_steps=args.ddim_steps, \
        fps_id=args.fps, motion_bucket_id=args.motion, cond_aug=args.cond_aug, seed=args.seed, \
        decoding_t=args.decoding_t, output_folder=args.savedir, save_fps=args.savefps, resize=args.resize, \
        s_w=args.s_w, e_w=args.e_w, s_h=args.s_h, e_h=args.e_h, ps_w=args.ps_w, pe_w=args.pe_w, ps_h=args.ps_h, pe_h=args.pe_h, x_bias=args.x_bias, y_bias=args.y_bias)


================================================
FILE: requirements.txt
================================================
black==23.7.0
chardet==5.1.0
clip
einops>=0.6.1
fairscale==0.4.13
fire==0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib==3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
omegaconf>=2.3.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==1.9.0
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm==0.4.12
tokenizers==0.12.1
torch==2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers==0.0.21
decord
deepspeed
yacs
controlnet_aux
PyAV
safetensors
imageio[ffmpeg]
gradio==3.50.2

================================================
FILE: sgm/__init__.py
================================================
from .models import AutoencodingEngine, DiffusionEngine
from .util import get_configs_path, instantiate_from_config

__version__ = "0.1.0"


================================================
FILE: sgm/inference/api.py
================================================
import pathlib
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional

from omegaconf import OmegaConf

from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
                                   do_sample)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
                                                   DPMPP2SAncestralSampler,
                                                   EulerAncestralSampler,
                                                   EulerEDMSampler,
                                                   HeunEDMSampler,
                                                   LinearMultistepSampler)
from sgm.util import load_model_from_config


class ModelArchitecture(str, Enum):
    SD_2_1 = "stable-diffusion-v2-1"
    SD_2_1_768 = "stable-diffusion-v2-1-768"
    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"


class Sampler(str, Enum):
    EULER_EDM = "EulerEDMSampler"
    HEUN_EDM = "HeunEDMSampler"
    EULER_ANCESTRAL = "EulerAncestralSampler"
    DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
    DPMPP2M = "DPMPP2MSampler"
    LINEAR_MULTISTEP = "LinearMultistepSampler"


class Discretization(str, Enum):
    LEGACY_DDPM = "LegacyDDPMDiscretization"
    EDM = "EDMDiscretization"


class Guider(str, Enum):
    VANILLA = "VanillaCFG"
    IDENTITY = "IdentityGuider"


class Thresholder(str, Enum):
    NONE = "None"


@dataclass
class SamplingParams:
    width: int = 1024
    height: int = 1024
    steps: int = 50
    sampler: Sampler = Sampler.DPMPP2M
    discretization: Discretization = Discretization.LEGACY_DDPM
    guider: Guider = Guider.VANILLA
    thresholder: Thresholder = Thresholder.NONE
    scale: float = 6.0
    aesthetic_score: float = 5.0
    negative_aesthetic_score: float = 5.0
    img2img_strength: float = 1.0
    orig_width: int = 1024
    orig_height: int = 1024
    crop_coords_top: int = 0
    crop_coords_left: int = 0
    sigma_min: float = 0.0292
    sigma_max: float = 14.6146
    rho: float = 3.0
    s_churn: float = 0.0
    s_tmin: float = 0.0
    s_tmax: float = 999.0
    s_noise: float = 1.0
    eta: float = 1.0
    order: int = 4


@dataclass
class SamplingSpec:
    width: int
    height: int
    channels: int
    factor: int
    is_legacy: bool
    config: str
    ckpt: str
    is_guided: bool


model_specs = {
    ModelArchitecture.SD_2_1: SamplingSpec(
        height=512,
        width=512,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_2_1.yaml",
        ckpt="v2-1_512-ema-pruned.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SD_2_1_768: SamplingSpec(
        height=768,
        width=768,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_2_1_768.yaml",
        ckpt="v2-1_768-ema-pruned.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=False,
        config="sd_xl_base.yaml",
        ckpt="sd_xl_base_0.9.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_xl_refiner.yaml",
        ckpt="sd_xl_refiner_0.9.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=False,
        config="sd_xl_base.yaml",
        ckpt="sd_xl_base_1.0.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_xl_refiner.yaml",
        ckpt="sd_xl_refiner_1.0.safetensors",
        is_guided=True,
    ),
}


class SamplingPipeline:
    def __init__(
        self,
        model_id: ModelArchitecture,
        model_path="checkpoints",
        config_path="configs/inference",
        device="cuda",
        use_fp16=True,
    ) -> None:
        if model_id not in model_specs:
            raise ValueError(f"Model {model_id} not supported")
        self.model_id = model_id
        self.specs = model_specs[self.model_id]
        self.config = str(pathlib.Path(config_path, self.specs.config))
        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
        self.device = device
        self.model = self._load_model(device=device, use_fp16=use_fp16)

    def _load_model(self, device="cuda", use_fp16=True):
        config = OmegaConf.load(self.config)
        model = load_model_from_config(config, self.ckpt)
        if model is None:
            raise ValueError(f"Model {self.model_id} could not be loaded")
        model.to(device)
        if use_fp16:
            model.conditioner.half()
            model.model.half()
        return model

    def text_to_image(
        self,
        params: SamplingParams,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        sampler = get_sampler_config(params)
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = params.width
        value_dict["target_height"] = params.height
        return do_sample(
            self.model,
            sampler,
            value_dict,
            samples,
            params.height,
            params.width,
            self.specs.channels,
            self.specs.factor,
            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
            return_latents=return_latents,
            filter=None,
        )

    def image_to_image(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        sampler = get_sampler_config(params)

        if params.img2img_strength < 1.0:
            sampler.discretization = Img2ImgDiscretizationWrapper(
                sampler.discretization,
                strength=params.img2img_strength,
            )
        height, width = image.shape[2], image.shape[3]
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = width
        value_dict["target_height"] = height
        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
            return_latents=return_latents,
            filter=None,
        )

    def refiner(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        samples: int = 1,
        return_latents: bool = False,
    ):
        sampler = get_sampler_config(params)
        value_dict = {
            "orig_width": image.shape[3] * 8,
            "orig_height": image.shape[2] * 8,
            "target_width": image.shape[3] * 8,
            "target_height": image.shape[2] * 8,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "crop_coords_top": 0,
            "crop_coords_left": 0,
            "aesthetic_score": 6.0,
            "negative_aesthetic_score": 2.5,
        }

        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            skip_encode=True,
            return_latents=return_latents,
            filter=None,
        )


def get_guider_config(params: SamplingParams):
    if params.guider == Guider.IDENTITY:
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif params.guider == Guider.VANILLA:
        scale = params.scale

        thresholder = params.thresholder

        if thresholder == Thresholder.NONE:
            dyn_thresh_config = {
                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
            }
        else:
            raise NotImplementedError

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
        }
    else:
        raise NotImplementedError
    return guider_config

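# Illustrative sketch of classifier-free guidance as selected by
# Guider.VANILLA above, assuming the standard CFG combination
# out = uncond + scale * (cond - uncond). `cfg_combine` is a hypothetical
# helper on plain lists; VanillaCFG itself operates on batched tensors.
def cfg_combine(x_uncond, x_cond, scale):
    """Element-wise classifier-free guidance on two equal-length sequences."""
    if len(x_uncond) != len(x_cond):
        raise ValueError("inputs must have the same length")
    return [u + scale * (c - u) for u, c in zip(x_uncond, x_cond)]

# scale=1.0 returns the conditional output and scale=0.0 the unconditional
# one; larger scales push further along (cond - uncond).
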

def get_discretization_config(params: SamplingParams):
    if params.discretization == Discretization.LEGACY_DDPM:
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    elif params.discretization == Discretization.EDM:
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": params.sigma_min,
                "sigma_max": params.sigma_max,
                "rho": params.rho,
            },
        }
    else:
        raise ValueError(f"unknown discretization {params.discretization}")
    return discretization_config

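# Illustrative sketch of the Karras et al. (EDM) noise schedule that the
# EDM discretization option is parameterized by (sigma_min, sigma_max, rho):
#   sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
# for i = 0..n-1, with defaults taken from SamplingParams. `edm_sigmas` is a
# hypothetical helper; sgm's EDMDiscretization may differ in details (for
# example by appending a final zero sigma), so treat this as a reference
# for the formula only.
def edm_sigmas(n, sigma_min=0.0292, sigma_max=14.6146, rho=3.0):
    if n < 2:
        raise ValueError("need at least two steps")
    lo = sigma_min ** (1.0 / rho)
    hi = sigma_max ** (1.0 / rho)
    return [(hi + i / (n - 1) * (lo - hi)) ** rho for i in range(n)]
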

def get_sampler_config(params: SamplingParams):
    discretization_config = get_discretization_config(params)
    guider_config = get_guider_config(params)
    sampler = None
    if params.sampler == Sampler.EULER_EDM:
        return EulerEDMSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=params.s_churn,
            s_tmin=params.s_tmin,
            s_tmax=params.s_tmax,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.HEUN_EDM:
        return HeunEDMSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=params.s_churn,
            s_tmin=params.s_tmin,
            s_tmax=params.s_tmax,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.EULER_ANCESTRAL:
        return EulerAncestralSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            eta=params.eta,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
        return DPMPP2SAncestralSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            eta=params.eta,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.DPMPP2M:
        return DPMPP2MSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            verbose=True,
        )
    if params.sampler == Sampler.LINEAR_MULTISTEP:
        return LinearMultistepSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=params.order,
            verbose=True,
        )

    raise ValueError(f"unknown sampler {params.sampler}!")


================================================
FILE: sgm/inference/helpers.py
================================================
import math
import os
from typing import List, Optional, Union

import numpy as np
import torch
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from PIL import Image
from torch import autocast

from sgm.util import append_dims


class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(watermark)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [0, 1]

        Returns:
            same as input but watermarked
        """
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        image_np = rearrange(
            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
        ).numpy()[:, :, :, ::-1]
        # torch (n, b, c, h, w) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255]
        # the watermarking library expects input in cv2 BGR format
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(
            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
        ).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        return image


# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)

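# Illustrative sketch of the bit encoding used above: bin(x)[2:] yields the
# binary digits of an integer as a string, which are turned into a list of
# 0/1 ints. The helper names are hypothetical; `bits_to_int` shows that the
# list round-trips to the original integer (the binary literal above equals
# the commented hex value 0xB3EC907BB19E).
def int_to_bits(value):
    return [int(bit) for bit in bin(value)[2:]]


def bits_to_int(bits):
    value = 0
    for bit in bits:
        value = (value << 1) | bit
    return value
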

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})


def perform_save_locally(save_path, samples):
    os.makedirs(os.path.join(save_path), exist_ok=True)
    base_count = len(os.listdir(os.path.join(save_path)))
    samples = embed_watermark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(sample.astype(np.uint8)).save(
            os.path.join(save_path, f"{base_count:09}.png")
        )
        base_count += 1


class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas are sorted from large to small
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # compute the cut-off once, on the full schedule
        prune_index = max(int(self.strength * len(sigmas)), 1)
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas

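# Illustrative sketch of the pruning in Img2ImgDiscretizationWrapper, on a
# plain list: sigmas arrive sorted from large to small, and flipping, taking
# the first max(int(strength * n), 1) values and flipping back keeps the last
# (smallest) noise levels, so a lower strength starts sampling from a less
# noisy point. The name `prune_sigmas` is hypothetical.
def prune_sigmas(sigmas, strength):
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    keep = max(int(strength * len(sigmas)), 1)
    return list(sigmas)[len(sigmas) - keep:]
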

def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: Optional[List] = None,
    batch2model_input: Optional[List] = None,
    return_latents=False,
    filter=None,
    device="cuda",
):
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []

    with torch.no_grad():
        with autocast(device) as precision_scope:
            with model.ema_scope():
                num_samples = [num_samples]
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                )
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(l) for l in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )

                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
                        )

                additional_model_inputs = {}
                for k in batch2model_input:
                    additional_model_inputs[k] = batch[k]

                shape = (math.prod(num_samples), C, H // F, W // F)
                randn = torch.randn(shape).to(device)

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                if return_latents:
                    return samples, samples_z
                return samples

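# Illustrative sketch of the latent noise shape drawn in `do_sample`: the
# sampler works in the autoencoder's latent space, which has C channels and
# is F times smaller than the image in each spatial dimension, with all
# batch dimensions flattened by math.prod. `latent_shape` is hypothetical.
import math


def latent_shape(batch_dims, H, W, C=4, F=8):
    return (math.prod(batch_dims), C, H // F, W // F)

# SDXL at 1024x1024 (channels=4, factor=8 in model_specs):
# latent_shape([2], 1024, 1024) -> (2, 4, 128, 128)
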

def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    # Hardcoded demo setups; might undergo some changes in the future

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor(
                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                )
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]])
                .to(device)
                .repeat(*N, 1)
            )

        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        else:
            batch[key] = value_dict[key]

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
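

# Sketch (illustrative, not part of the original helpers): how get_batch expands
# per-sample values for a batch shape N such as [2, 3]. Text prompts are
# repeated math.prod(N) times and reshaped into a nested list of shape N,
# while size/crop tuples become arrays of shape (*N, 2), like .repeat(*N, 1).
# NumPy's np.tile stands in for torch's Tensor.repeat here (both prepend new
# dims when given more repeat counts than the input has dims). The function
# name _sketch_expand_for_batch is hypothetical.
import math

import numpy as np


def _sketch_expand_for_batch(prompt, size_hw, N):
    # Nested list of prompts with shape N, as built for batch["txt"].
    prompts = np.repeat([prompt], repeats=math.prod(N)).reshape(N).tolist()
    # (h, w) pair tiled to shape (*N, 2), like torch.tensor([h, w]).repeat(*N, 1).
    sizes = np.tile(np.array(size_hw), (*N, 1))
    return prompts, sizes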


def get_input_image_tensor(image: Image.Image, device="cuda"):
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    width, height = map(
        lambda x: x - x % 64, (w, h)
    )  # resize to integer multiple of 64
    image = image.resize((width, height))
    image_array = np.array(image.convert("RGB"))
    image_array = image_array[None].transpose(0, 3, 1, 2)
    image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
    return image_tensor.to(device)
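

# Sketch (illustrative): the arithmetic inside get_input_image_tensor without
# PIL or torch. Each side is floored to a multiple of 64 (w - w % 64), and
# uint8 pixels in [0, 255] are mapped to [-1, 1] via x / 127.5 - 1.
# The helper names below are hypothetical.
import numpy as np


def _sketch_floor_to_multiple_of_64(w, h):
    return w - w % 64, h - h % 64


def _sketch_to_model_range(image_array):
    # HWC uint8 -> 1xCxHxW float32 in [-1, 1], mirroring the transpose above.
    chw = image_array[None].transpose(0, 3, 1, 2)
    return chw.astype(np.float32) / 127.5 - 1.0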


def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=[],
    additional_kwargs={},
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    device="cuda",
):
    with torch.no_grad():
        with autocast(device) as precision_scope:
            with model.ema_scope():
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [num_samples],
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )

                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))

                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]
                if skip_encode:
                    z = img
                else:
                    z = model.encode_first_stage(img)
                noise = torch.randn_like(z)
                sigmas = sampler.discretization(sampler.num_steps)
                sigma = sigmas[0].to(z.device)

                if offset_noise_level > 0.0:
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                noised_z = z + noise * append_dims(sigma, z.ndim)
                noised_z = noised_z / torch.sqrt(
                    1.0 + sigmas[0] ** 2.0
                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                if return_latents:
                    return samples, samples_z
                return samples
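

# Sketch (illustrative): the noising step used by do_img2img, in NumPy.
# The clean latent z is perturbed at the first sigma of the discretization,
# optionally with "offset noise" (one extra scalar per sample broadcast over
# all remaining dims, which is what append_dims achieves), and then divided by
# sqrt(1 + sigma**2), the hardcoded DDPM-like scaling noted above. The offset
# is passed in explicitly for determinism; the original draws it with
# torch.randn(z.shape[0]). Names here are hypothetical.
import numpy as np


def _sketch_img2img_noise(z, noise, sigma, offset=None, offset_noise_level=0.0):
    if offset_noise_level > 0.0 and offset is not None:
        # offset has shape (B,); reshape to (B, 1, 1, ...) to broadcast like append_dims.
        offset = offset.reshape(offset.shape[0], *([1] * (z.ndim - 1)))
        noise = noise + offset_noise_level * offset
    noised_z = z + noise * sigma
    return noised_z / np.sqrt(1.0 + sigma**2.0)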


================================================
FILE: sgm/lr_scheduler.py
================================================
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (
                self.lr_max_decay_steps - self.lr_warm_up_steps
            )
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(t * np.pi)
            )
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
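

# Sketch (illustrative): the value computed by LambdaWarmUpCosineScheduler,
# written as a pure function. Linear warm-up from lr_start to lr_max over
# warm_up_steps, then half-cosine decay from lr_max to lr_min, held at lr_min
# once max_decay_steps is reached. The value is a multiplier on the optimizer's
# base lr (e.g. through torch.optim.lr_scheduler.LambdaLR), which is why the
# docstring asks for a base_lr of 1.0. The function name is hypothetical.
import numpy as np


def _sketch_warmup_cosine(n, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps):
    if n < warm_up_steps:
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(t * np.pi))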


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi)
            )
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
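

# Sketch (illustrative): how LambdaWarmUpCosineScheduler2.find_in_interval maps a
# global step n to a cycle index using cumulative cycle lengths, and how the
# step is then made local to that cycle (n - cum_cycles[cycle]). A step equal
# to a boundary still belongs to the earlier cycle (n <= cl), and a step past
# the last boundary falls through to None, as in the original loop.
# The function name is hypothetical.
import numpy as np


def _sketch_locate_cycle(n, cycle_lengths):
    cum_cycles = np.cumsum([0] + list(cycle_lengths))
    for interval, cl in enumerate(cum_cycles[1:]):
        if n <= cl:
            return interval, n - cum_cycles[interval]
    return None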


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
                self.cycle_lengths[cycle] - n
            ) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
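

# Sketch (illustrative): the per-cycle multiplier of LambdaLinearScheduler.
# Warm-up is identical to LambdaWarmUpCosineScheduler2; afterwards the value
# decays linearly as f_min + (f_max - f_min) * (L - n) / L, where L is the
# cycle length and n the step within the cycle. Because the decay is measured
# from the start of the cycle rather than from the end of warm-up, the value
# right after warm-up is slightly below f_max. The function name is hypothetical.


def _sketch_linear_multiplier(n, warm_up, f_start, f_max, f_min, cycle_length):
    if n < warm_up:
        return (f_max - f_start) / warm_up * n + f_start
    return f_min + (f_max - f_min) * (cycle_length - n) / cycle_length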


================================================
FILE: sgm/models/__init__.py
================================================
from .autoencoder import AutoencodingEngine
from .diffusion import DiffusionEngine


================================================
FILE: sgm/models/autoencoder.py
================================================
import logging
import math
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version

from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.ema import LitEma
from ..util import (default, get_nested_attribute, get_obj_from_str,
                    instantiate_from_config)

logpy = logging.getLogger(__name__)


class AbstractAutoencoder(pl.LightningModule):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        if isinstance(ckpt, str):
            ckpt = {
                "target": "sgm.modules.checkpoint.CheckpointEngine",
                "params": {"ckpt_path": ckpt},
            }
        engine = instantiate_from_config(ckpt)
        engine(self)

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f"{context}: Restored training weights")

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()
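

# Sketch (illustrative): the swap-and-restore pattern behind ema_scope, using
# plain dicts instead of LitEma. The EMA weights are copied in on entry, and
# the training weights are always restored on exit, even if the body raises,
# because the restore sits in a finally clause. _SketchEmaHolder is a
# hypothetical stand-in, not the LitEma API.
from contextlib import contextmanager


class _SketchEmaHolder:
    def __init__(self, params, ema_params):
        self.params = dict(params)
        self.ema_params = dict(ema_params)

    @contextmanager
    def ema_scope(self):
        stored = dict(self.params)  # like model_ema.store(self.parameters())
        self.params.update(self.ema_params)  # like model_ema.copy_to(self)
        try:
            yield None
        finally:
            self.params = stored  # like model_ema.restore(self.parameters())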


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        trainable_ae_params: Optional[List[List[str]]] = None,
        ae_optimizer_args: Optional[List[dict]] = None,
        trainable_disc_params: Optional[List[List[str]]] = None,
        disc_optimizer_args: Optional[List[dict]] = None,
        disc_start_iter: int = 0,
        diff_boost_factor: float = 3.0,
        ckpt_engine: Union[None, str, dict] = None,
        ckpt_path: Optional[str] = None,
        additional_decode_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False  # pytorch lightning

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.loss: torch.nn.Module = instantiate_from_config(loss_config)
        self.regularization: AbstractRegularizer = instantiate_from_config(
            regularizer_config
        )
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.Adam"}
        )
        self.diff_boost_factor = diff_boost_factor
        self.disc_start_iter = disc_start_iter
        self.lr_g_factor = lr_g_factor
        self.trainable_ae_params = trainable_ae_params
        if self.trainable_ae_params is not None:
            self.ae_optimizer_args = default(
                ae_optimizer_args,
                [{} for _ in range(len(self.trainable_ae_params))],
            )
            assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
        else:
            self.ae_optimizer_args = [{}]  # makes type consistent

        self.trainable_disc_params = trainable_disc_params
        if self.trainable_disc_params is not None:
            self.disc_optimizer_args = default(
                disc_optimizer_args,
                [{} for _ in range(len(self.trainable_disc_params))],
            )
            assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
        else:
            self.disc_optimizer_args = [{}]  # makes type consistent

        if ckpt_path is not None:
            assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
            logpy.warning("Checkpoint path is deprecated, use `ckpt_engine` instead")
        self.apply_ckpt(default(ckpt_path, ckpt_engine))
        self.additional_decode_keys = set(default(additional_decode_keys, []))

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first
        # format (e.g., bchw instead of bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = []
        if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
            params += list(self.loss.get_trainable_autoencoder_parameters())
        if hasattr(self.regularization, "get_trainable_parameters"):
            params += list(self.regularization.get_trainable_parameters())
        params = params + list(self.encoder.parameters())
        params = params + list(self.decoder.parameters())
        return params

    def get_discriminator_params(self) -> list:
        if hasattr(self.loss, "get_trainable_parameters"):
            params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        else:
            params = []
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        z = self.encoder(x)
        if unregularized:
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.decoder(z, **kwargs)
        return x

    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log

    def inner_training_step(
        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
    ) -> torch.Tensor:
        x = self.get_input(batch)
        additional_decode_kwargs = {
            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
        }
        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": optimizer_idx,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "train",
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()

        if optimizer_idx == 0:
            # autoencode
            out_loss = self.loss(x, xrec, **extra_info)
            if isinstance(out_loss, tuple):
                aeloss, log_dict_ae = out_loss
            else:
                # simple loss function
                aeloss = out_loss
                log_dict_ae = {"train/loss/rec": aeloss.detach()}

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=False,
            )
            self.log(
                "loss",
                aeloss.mean().detach(),
                prog_bar=True,
                logger=False,
                on_epoch=False,
                on_step=True,
            )
            return aeloss
        elif optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            # -> discriminator always needs to return a tuple
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss
        else:
            raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")

    def training_step(self, batch: dict, batch_idx: int):
        opts = self.optimizers()
        if not isinstance(opts, list):
            # Non-adversarial case
            opts = [opts]
        optimizer_idx = batch_idx % len(opts)
        if self.global_step < self.disc_start_iter:
            optimizer_idx = 0
        opt = opts[optimizer_idx]
        opt.zero_grad()
        with opt.toggle_model():
            loss = self.inner_training_step(
                batch, batch_idx, optimizer_idx=optimizer_idx
            )
            self.manual_backward(loss)
        opt.step()

    def validation_step(self, batch: dict, batch_idx: int) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": 0,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "val" + postfix,
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()
        out_loss = self.loss(x, xrec, **extra_info)
        if isinstance(out_loss, tuple):
            aeloss, log_dict_ae = out_loss
        else:
            # simple loss function
            aeloss = out_loss
            log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
        full_log_dict = log_dict_ae

        if "optimizer_idx" in extra_info:
            extra_info["optimizer_idx"] = 1
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            full_log_dict.update(log_dict_disc)
        self.log(
            f"val{postfix}/loss/rec",
            log_dict_ae[f"val{postfix}/loss/rec"],
            sync_dist=True,
        )
        self.log_dict(full_log_dict, sync_dist=True)
        return full_log_dict

    def get_param_groups(
        self, parameter_names: List[List[str]], optimizer_args: List[dict]
    ) -> Tuple[List[Dict[str, Any]], int]:
        groups = []
        num_params = 0
        for names, args in zip(parameter_names, optimizer_args):
            params = []
            for pattern_ in names:
                pattern_params = []
                pattern = re.compile(pattern_)
                for p_name, param in self.named_parameters():
                    if re.match(pattern, p_name):
                        pattern_params.append(param)
                        num_params += param.numel()
                if len(pattern_params) == 0:
                    logpy.warning(f"Did not find parameters for pattern {pattern_}")
                params.extend(pattern_params)
            groups.append({"params": params, **args})
        return groups, num_params

    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        if self.trainable_ae_params is None:
            ae_params = self.get_autoencoder_params()
        else:
            ae_params, num_ae_params = self.get_param_groups(
                self.trainable_ae_params, self.ae_optimizer_args
            )
            logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
        if self.trainable_disc_params is None:
            disc_params = self.get_discriminator_params()
        else:
            disc_params, num_disc_params = self.get_param_groups(
                self.trainable_disc_params, self.disc_optimizer_args
            )
            logpy.info(
                f"Number of trainable discriminator parameters: {num_disc_params:,}"
            )
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opts = [opt_ae]
        if len(disc_params) > 0:
            opt_disc = self.instantiate_optimizer_from_config(
                disc_params, self.learning_rate, self.optimizer_config
            )
            opts.append(opt_disc)

        return opts

    @torch.no_grad()
    def log_images(
        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
    ) -> dict:
        log = dict()
        additional_decode_kwargs = {}
        x = self.get_input(batch)
        additional_decode_kwargs.update(
            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
        )

        _, xrec, _ = self(x, **additional_decode_kwargs)
        log["inputs"] = x
        log["reconstructions"] = xrec
        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
        diff.clamp_(0, 1.0)
        log["diff"] = 2.0 * diff - 1.0
        # diff_boost shows location of small errors, by boosting their
        # brightness.
        log["diff_boost"] = (
            2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
        )
        if hasattr(self.loss, "log_images"):
            log.update(self.loss.log_images(x, xrec))
        with self.ema_scope():
            _, xrec_ema, _ = self(x, **additional_decode_kwargs)
            log["reconstructions_ema"] = xrec_ema
            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
            diff_ema.clamp_(0, 1.0)
            log["diff_ema"] = 2.0 * diff_ema - 1.0
            log["diff_boost_ema"] = (
                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
            )
        if additional_log_kwargs:
            additional_decode_kwargs.update(additional_log_kwargs)
            _, xrec_add, _ = self(x, **additional_decode_kwargs)
            log_str = "reconstructions-" + "-".join(
                [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
            )
            log[log_str] = xrec_add
        return log
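

# Sketch (illustrative): how AutoencodingEngine.training_step picks which
# optimizer to step. With an autoencoder and a discriminator optimizer, even
# batch indices update the autoencoder (index 0) and odd ones the
# discriminator (index 1), except that before disc_start_iter every batch
# updates the autoencoder. The function name is hypothetical.


def _sketch_select_optimizer_idx(batch_idx, global_step, num_optimizers, disc_start_iter):
    optimizer_idx = batch_idx % num_optimizers
    if global_step < disc_start_iter:
        optimizer_idx = 0
    return optimizer_idx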


class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ckpt_engine = kwargs.pop("ckpt_engine", None)
        super().__init__(
            encoder_config={
                "target": "sgm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "sgm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        self.apply_ckpt(default(ckpt_path, ckpt_engine))

    def get_autoencoder_params(self) -> list:
        params = super().get_autoencoder_params()
        return params

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec
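

# Sketch (illustrative): the max_batch_size chunking used by
# AutoencodingEngineLegacy.encode/decode. The batch is split into
# ceil(N / bs) slices, each slice is processed independently, and the results
# are concatenated along dim 0. This bounds peak memory without changing the
# output of per-sample functions. NumPy stands in for torch; the function name
# is hypothetical.
import math

import numpy as np


def _sketch_apply_in_chunks(fn, x, max_batch_size=None):
    if max_batch_size is None:
        return fn(x)
    n_batches = int(math.ceil(x.shape[0] / max_batch_size))
    outs = [
        fn(x[i * max_batch_size : (i + 1) * max_batch_size]) for i in range(n_batches)
    ]
    return np.concatenate(outs, axis=0)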


class AutoencoderKL(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                )
            },
            **kwargs,
        )


class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
    def __init__(
        self,
        embed_dim: int,
        n_embed: int,
        sane_index_shape: bool = False,
        **kwargs,
    ):
        if "lossconfig" in kwargs:
            logpy.warning("Parameter `lossconfig` is deprecated, use `loss_config`.")
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
                ),
                "params": {
                    "n_e": n_embed,
                    "e_dim": embed_dim,
                    "sane_index_shape": sane_index_shape,
                },
            },
            **kwargs,
        )


class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x


class AEIntegerWrapper(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
        regularization_key: str = "regularization",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        assert hasattr(model, "encode") and hasattr(
            model, "decode"
        ), "Need AE interface"
        self.regularization = get_nested_attribute(model, regularization_key)
        self.shape = shape
        self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})

    def encode(self, x) -> torch.Tensor:
        assert (
            not self.training
        ), f"{self.__class__.__name__} only supports inference currently"
        _, log = self.model.encode(x, **self.encoder_kwargs)
        assert isinstance(log, dict)
        inds = log["min_encoding_indices"]
        return rearrange(inds, "b ... -> b (...)")

    def decode(
        self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
    ) -> torch.Tensor:
        # expect inds shape (b, s) with s = h*w
        shape = default(shape, self.shape)  # Optional[(h, w)]
        if shape is not None:
            assert len(shape) == 2, f"Unhandled shape {shape}"
            inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
        h = self.regularization.get_codebook_entry(inds)  # (b, h, w, c)
        h = rearrange(h, "b h w c -> b c h w")
        return self.model.decode(h)
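

# Sketch (illustrative): the index handling in AEIntegerWrapper.decode, in NumPy.
# Flat code indices of shape (b, h*w) are reshaped to (b, h, w) (the einops
# "b (h w) -> b h w" pattern is a row-major reshape), looked up in a codebook
# to get (b, h, w, c), then moved to channels-first (b, c, h, w). Integer
# indexing into a plain array stands in for regularization.get_codebook_entry;
# the function name is hypothetical.
import numpy as np


def _sketch_indices_to_latent(inds, codebook, shape):
    h, w = shape
    inds = inds.reshape(inds.shape[0], h, w)  # b (h w) -> b h w
    z = codebook[inds]  # (b, h, w, c)
    return z.transpose(0, 3, 1, 2)  # b h w c -> b c h w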


class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                ),
                "params": {"sample": False},
            },
            **kwargs,
        )

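# Illustrative sketch (hypothetical helper, not the library implementation):
# with "sample": False the regularizer presumably returns the posterior mode
# (the mean) instead of a reparameterized sample. This assumes the encoder
# output stacks mean and log-variance along the channel axis, as in a
# standard diagonal Gaussian posterior; the clamp range is also an assumption.
import torch


def _example_diag_gaussian(moments: torch.Tensor, sample: bool, generator=None) -> torch.Tensor:
    mean, logvar = torch.chunk(moments, 2, dim=1)
    if not sample:
        return mean  # deterministic "mode-only" latent
    std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
    noise = torch.randn(mean.shape, generator=generator, dtype=mean.dtype, device=mean.device)
    return mean + std * noise  # reparameterization trick
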

================================================
FILE: sgm/models/diffusion.py
================================================
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from einops import rearrange, repeat

import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR

from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (default, disabled_train, get_obj_from_str,
                    instantiate_from_config, log_txt_as_img)

class DiffusionEngine(pl.LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        ## update with num_frames
        self.num_frames = network_config.params.num_frames

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def init_from_ckpt(
        self,
        path: str,
    ) -> None:
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out
    
    # @torch.no_grad()
    def decode_first_stage_train(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x):
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(
                    x[n * n_samples : (n + 1) * n_samples]
                )
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z

        return z

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

    def get_input(self, batch):
        # assuming unified data format, dataloader returns a dict.
        # video tensors should be scaled to -1 ... 1 and in b c t h w format;
        # they are flattened to (b t) c h w frames for the first-stage model
        x = batch[self.input_key]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        return x

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        x = self.encode_first_stage(x)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    def on_train_start(self, *args, **kwargs):
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training.")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                params = params + list(embedder.parameters())

        print(f"@Training [{len(params)}] parameters.")
        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(self.device)

        denoiser = lambda input, sigma, c: self.denoiser(
            self.model, input, sigma, c, **kwargs
        )
        samples = self.sampler(denoiser, randn, cond, uc=uc)
        return samples

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        # the last two dims are (h, w) for both b c h w and b c t h w inputs
        image_h, image_w = batch[self.input_key].shape[-2:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if (
                (self.log_keys is None) or (embedder.input_key in self.log_keys)
            ) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            "x".join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        # strings
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_images(
        self,
        batch: Dict,
        N: int = 999,
        sample: bool = True,
        ucg_keys: List[str] = None,
        **kwargs,
    ) -> Dict:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)
        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0
            else [],
        )
        ## repeat conditions for each frames
        for k in ["crossattn", "concat", "vector"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=self.num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=self.num_frames)
            c[k] = repeat(c[k], "b ... -> b t ...", t=self.num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=self.num_frames)

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        b = x.shape[0] // self.num_frames
        log["inputs"] = rearrange(x, '(b t) c h w -> b c t h w', b=b, t=self.num_frames)
        z = self.encode_first_stage(x)
        if self.disable_first_stage_autocast:
            self.disable_first_stage_autocast = False
            x_recon = self.decode_first_stage(z)
            self.disable_first_stage_autocast = True
        else:
            x_recon = self.decode_first_stage(z)
        log["reconstructions"] = rearrange(x_recon, '(b t) c h w -> b c t h w', b=b, t=self.num_frames)

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            if self.disable_first_stage_autocast:
                self.disable_first_stage_autocast = False
                samples = self.decode_first_stage(samples)
                self.disable_first_stage_autocast = True
            else:
                samples = self.decode_first_stage(samples)
            log["samples"] = rearrange(samples, '(b t) c h w -> b c t h w', b=b, t=self.num_frames)
        return log

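# Illustrative sketch (hypothetical helper, not part of the original module):
# log_images broadcasts each per-video condition to every frame with
# repeat("b ... -> b t ...") followed by rearrange("b t ... -> (b t) ..."),
# so condition i lines up with frames i*t ... i*t + t - 1 of the
# "(b t) c h w" batch produced by get_input. That is the same ordering as
# torch.repeat_interleave along the batch dimension.
import torch
from einops import rearrange, repeat


def _example_repeat_per_frame(cond: torch.Tensor, num_frames: int) -> torch.Tensor:
    out = repeat(cond, "b ... -> b t ...", t=num_frames)
    return rearrange(out, "b t ... -> (b t) ...", t=num_frames)
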

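# Illustrative sketch (hypothetical helper, not part of the original module):
# encode_first_stage / decode_first_stage split the batch into
# ceil(N / n_samples) slices of at most n_samples items, run the first-stage
# model on each slice and concatenate the results. For a model that treats
# batch items independently this matches one full-batch call while bounding
# peak memory. Note that the VideoDecoder branch also passes
# timesteps=len(chunk), so its temporal layers see each chunk as one clip and
# the result there can depend on the chunk size.
import math
from typing import Callable, Optional

import torch


def _example_chunked_apply(
    fn: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    n_samples: Optional[int] = None,
) -> torch.Tensor:
    n_samples = x.shape[0] if n_samples is None else n_samples  # like default(...)
    n_rounds = math.ceil(x.shape[0] / n_samples)
    outs = [fn(x[n * n_samples : (n + 1) * n_samples]) for n in range(n_rounds)]
    return torch.cat(outs, dim=0)
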
================================================
FILE: sgm/modules/__init__.py
================================================
from .encoders.modules import GeneralConditioner

UNCONDITIONAL_CONFIG = {
    "target": "sgm.modules.GeneralConditioner",
    "params": {"emb_models": []},
}


================================================
FILE: sgm/modules/attention.py
================================================
import logging
import math
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
from torch.utils.checkpoint import checkpoint

logpy = logging.getLogger(__name__)

if version.parse(torch.__version__) >= version.parse("2.0.0"):
    SDP_IS_AVAILABLE = True
    from torch.backends.cuda import SDPBackend, sdp_kernel

    BACKEND_MAP = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
    }
else:
    from contextlib import nullcontext

    SDP_IS_AVAILABLE = False
    sdp_kernel = nullcontext
    BACKEND_MAP = {}
    logpy.warning(
        f"No SDP backend available, likely because you are running in pytorch "
        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
        f"You might want to consider upgrading."
    )

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
except Exception:
    XFORMERS_IS_AVAILABLE = False
    logpy.warning("no module 'xformers'. Processing without...")

# from .diffusionmodules.util import mixed_checkpoint as checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

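# Illustrative sketch (hypothetical helper, not part of the original module):
# GEGLU projects to 2 * dim_out features, splits them into a value half and a
# gate half, and returns value * GELU(gate). This recomputes the same thing
# explicitly from the weights of a Linear layer such as GEGLU.proj.
import torch
import torch.nn.functional as F


def _example_geglu_reference(proj: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    dim_out = proj.out_features // 2
    h = x @ proj.weight.t() + proj.bias
    value, gate = h[..., :dim_out], h[..., dim_out:]
    return value * F.gelu(gate)
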

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

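# Illustrative sketch (hypothetical helper, not part of the original module):
# zero_module is commonly applied to the last layer of a residual branch so
# that the branch contributes nothing at initialization, i.e.
# x + branch(x) == x. The helper zeroes a Linear layer in place the same way.
import torch
from torch import nn


def _example_zero_init_residual(dim: int = 8) -> nn.Module:
    branch = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
    for p in branch[-1].parameters():
        p.detach().zero_()  # same in-place zeroing as zero_module
    return branch
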

def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)

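# Illustrative sketch (hypothetical helper, not part of the original module):
# LinearAttention never forms the (h*w) x (h*w) attention matrix. With
# per-head q, k, v of shape (c, n) and k softmax-normalised over the n
# positions, it computes context = k @ v^T (c x c) and then context^T @ q,
# which equals v @ (k^T @ q); the cost grows linearly in n, not quadratically.
import torch


def _example_linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    # q, k, v: (b, heads, c, n), same layout as in LinearAttention.forward
    k = k.softmax(dim=-1)
    context = torch.einsum("bhdn,bhen->bhde", k, v)
    fast = torch.einsum("bhde,bhdn->bhen", context, q)
    # quadratic reference that materialises the (n x n) matrix k^T @ q
    slow = v @ (k.transpose(-2, -1) @ q)
    return fast, slow
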

class SelfAttention(nn.Module):
    ATTENTION_MODES = ("xformers", "torch", "math")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        attn_mode: str = "xformers",
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, C = x.shape

        qkv = self.qkv(x)
        if self.attn_mode == "torch":
            qkv = rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            ).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "B H L D -> B L (H D)")
        elif self.attn_mode == "xformers":
            qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
        elif self.attn_mode == "math":
            qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplementedError(f"Unknown attention mode {self.attn_mode}")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

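# Illustrative sketch (hypothetical helper, not part of the original module):
# the "math" branch of SelfAttention computes softmax(q k^T * scale) v
# explicitly, while the "torch" branch calls F.scaled_dot_product_attention,
# whose default scale is head_dim ** -0.5. With qk_scale=None and no dropout
# the two branches should agree numerically (requires PyTorch >= 2.0).
import torch
import torch.nn.functional as F


def _example_attention_paths(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    # q, k, v: (B, H, L, D)
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    math_out = attn.softmax(dim=-1) @ v
    sdpa_out = F.scaled_dot_product_attention(q, k, v)
    return math_out, sdpa_out
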

class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        backend=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.backend = backend

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        h = self.heads

        if additional_tokens is not None:
            # get the number of masked tokens at the beginning of the output sequence
            n_tokens_to_mask = additional_tokens.shape[1]
            # add additional token
            x = torch.cat([additional_tokens, x], dim=1)

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        if n_times_crossframe_attn_in_self:
            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
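            # keep every n-th frame as the anchor and repeat it n times so the
            # key/value batch size matches the query batch size
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )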

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        ## old
        """
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)
        """
        ## new
        with sdp_kernel(**BACKEND_MAP[self.backend]):
            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask
            )  # scale is dim_head ** -0.5 per default

        del q, k, v
        out = rearrange(out, "b h n d -> b n (h d)", h=h)

        if additional_tokens is not None:
            # remove additional token
            out = out[:, n_tokens_to_mask:]
        return self.to_out(out)

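# Illustrative sketch (hypothetical helper, not part of the original module):
# with n_times_crossframe_attn_in_self = n, frames are grouped into
# consecutive blocks of n and every frame in a block attends to the keys and
# values of the block's first frame. k[::n] keeps b / n anchor frames, and
# repeating each anchor n times restores the batch size b.
import torch
from einops import repeat


def _example_crossframe_kv(k: torch.Tensor, n: int) -> torch.Tensor:
    assert k.shape[0] % n == 0
    return repeat(k[::n], "b ... -> (b n) ...", n=n)
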

class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(
        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
    ):
        super().__init__()
        logpy.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
            f"context_dim is {context_dim} and using {heads} heads with a "
            f"dimension of {dim_head}."
        )
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
        align_w_first_frame = False,
    ):
        if additional_tokens is not None:
            # get the number of masked tokens at the beginning of the output sequence
            n_tokens_to_mask = additional_tokens.shape[1]
            # add additional token
            x = torch.cat([additional_tokens, x], dim=1)
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        if n_times_crossframe_attn_in_self:
            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        if version.parse(xformers.__version__) >= version.parse("0.0.21"):
            # NOTE: workaround for
            # https://github.com/facebookresearch/xformers/issues/845
            max_bs = 32768
            N = q.shape[0]
            n_batches = math.ceil(N / max_bs)
            out = list()
            for i_batch in range(n_batches):
                batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
                out.append(
                    xformers.ops.memory_efficient_attention(
                        q[batch],
                        k[batch],
                        v[batch],
                        attn_bias=None,
                        op=self.attention_op,
                    )
                )
            out = torch.cat(out, 0)
        else:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=self.attention_op
            )

        # TODO: Use this directly in the attention operation, as a bias
        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        if additional_tokens is not None:
            # remove additional token
            out = out[:, n_tokens_to_mask:]
        return self.to_out(out)

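# Illustrative sketch (hypothetical helper, not part of the original module):
# MemoryEfficientCrossAttention folds heads into the batch axis,
# (b, n, h*d) -> (b*h, n, d), runs attention in slices of at most max_bs rows
# (the xformers >= 0.0.21 workaround), concatenates the slices, and unfolds
# back to (b, n, h*d). F.scaled_dot_product_attention stands in for
# xformers.ops.memory_efficient_attention here, an assumption made so the
# sketch runs without xformers.
import math

import torch
import torch.nn.functional as F


def _example_chunked_head_attention(q, k, v, heads: int, max_bs: int):
    b, n, inner = q.shape
    d = inner // heads

    def fold(t):
        # (b, len, h*d) -> (b*h, len, d)
        return t.reshape(b, t.shape[1], heads, d).permute(0, 2, 1, 3).reshape(b * heads, t.shape[1], d)

    q, k, v = map(fold, (q, k, v))
    chunks = []
    for i in range(math.ceil(q.shape[0] / max_bs)):
        sl = slice(i * max_bs, (i + 1) * max_bs)
        chunks.append(F.scaled_dot_product_attention(q[sl], k[sl], v[sl]))
    out = torch.cat(chunks, dim=0)
    # (b*h, n, d) -> (b, n, h*d)
    return out.reshape(b, heads, n, d).permute(0, 2, 1, 3).reshape(b, n, heads * d)
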

class BasicTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention,  # ampere
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
        attn_mode="softmax",
        sdp_backend=None,
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
            logpy.warning(
                f"Attention mode '{attn_mode}' is not available. Falling "
                f"back to native attention. This is not a problem in "
                f"Pytorch >= 2.0. FYI, you are running with PyTorch "
                f"version {torch.__version__}."
            )
            attn_mode = "softmax"
        elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
            logpy.warning(
                "We do not support vanilla attention anymore, as it is too "
                "expensive. Sorry."
            )
            if not XFORMERS_IS_AVAILABLE:
                assert (
                    False
                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
            else:
                logpy.info("Falling back to xformers efficient attention.")
                attn_mode = "softmax-xformers"
        attn_cls = self.ATTENTION_MODES[attn_mode]
        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
        else:
            assert sdp_backend is None
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
            backend=sdp_backend,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            backend=sdp_backend,
        )  # is self-attn if context is None
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        if self.checkpoint:
            logpy.debug(f"{self.__class__.__name__} is using checkpointing")

    def forward(
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        if self.checkpoint:
            # pass every argument of _forward positionally so that none of them
            # is silently dropped when gradient checkpointing is enabled
            return checkpoint(
                self._forward,
                x,
                context,
                additional_tokens,
                n_times_crossframe_attn_in_self,
            )
        return self._forward(
            x,
            context=context,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self,
        )

    def _forward(
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        x = (
            self.attn1(
                self.norm1(x),
                context=context if self.disable_self_attn else None,
                additional_tokens=additional_tokens,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
                if not self.disable_self_attn
                else 0,
            )
            + x
        )
        x = (
            self.attn2(
                self.norm2(x), context=context, additional_tokens=additional_tokens
            )
            + x
        )
        x = self.ff(self.norm3(x)) + x
        return x
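
# Illustrative sketch (hypothetical helper `resolve_attn_mode`, not part of sgm):
# the attention-mode fallback in BasicTransformerBlock.__init__ written as a pure
# function, so the decision table is easy to read and test. The two boolean flags
# stand in for the module-level XFORMERS_IS_AVAILABLE / SDP_IS_AVAILABLE checks,
# and the accepted modes are assumed to be the same two listed in
# BasicTransformerSingleLayerBlock.ATTENTION_MODES.
def resolve_attn_mode(attn_mode, xformers_available, sdp_available):
    if attn_mode not in ("softmax", "softmax-xformers"):
        raise ValueError(f"unknown attention mode {attn_mode!r}")
    if attn_mode != "softmax" and not xformers_available:
        # xformers was requested but is missing: use the native PyTorch path
        return "softmax"
    if attn_mode == "softmax" and not sdp_available:
        # no fused SDP kernel (PyTorch < 2.0): xformers is the only option left
        if not xformers_available:
            raise ImportError("install xformers or upgrade to PyTorch >= 2.0")
        return "softmax-xformers"
    return attn_mode
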


class BasicTransformerSingleLayerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention  # on A100s not quite as fast as the above version
        # (TODO: speed might depend on head_dim; xformers falls back to semi-optimized kernels for dim not in [16, 32, 64, 128])
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        attn_mode="softmax",
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # honour the `checkpoint` flag instead of always checkpointing
        if self.checkpoint:
            return checkpoint(self._forward, x, context)
        return self._forward(x, context)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context) + x
        x = self.ff(self.norm2(x)) + x
        return x
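
# Illustrative sketch (hypothetical class `PreNormBlockSketch`, not part of sgm)
# of the pre-norm residual pattern used by BasicTransformerBlock._forward and
# BasicTransformerSingleLayerBlock._forward above: every sub-layer sees a
# LayerNorm-ed input and its output is added back onto the residual stream.
# Assumption: torch.nn.MultiheadAttention and a plain GELU MLP stand in for
# sgm's CrossAttention and GEGLU FeedForward, so only the data flow matches.
import torch
import torch.nn as nn


class PreNormBlockSketch(nn.Module):
    def __init__(self, dim, n_heads, context_dim=None):
        super().__init__()
        context_dim = context_dim or dim
        self.attn1 = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        # kdim/vdim let keys and values come from a context of another width
        self.attn2 = nn.MultiheadAttention(
            dim, n_heads, kdim=context_dim, vdim=context_dim, batch_first=True
        )
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        # x: (batch, tokens, dim); context: (batch, ctx_tokens, context_dim)
        h = self.norm1(x)
        x = self.attn1(h, h, h, need_weights=False)[0] + x  # self-attention
        h = self.norm2(x)
        # like sgm, fall back to self-attention when no context is given
        # (only valid when context_dim == dim)
        kv = h if context is None else context
        x = self.attn2(h, kv, kv, need_weights=False)[0] + x  # cross-attention
        return self.ff(self.norm3(x)) + x  # position-wise feed-forward
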


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape it to (b, h*w, d),
    i.e. one token per spatial position.
    Then apply standard transformer blocks.
    Finally, reshape back to an image and add the input as a residual.
    NEW: use_linear=True uses nn.Linear projections instead of the 1x1 convs
    for more efficiency.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        attn_type="softmax",
        use_checkpoint=True,
        # sdp_backend=SDPBackend.FLASH_ATTENTION
        sdp_backend=None,
    ):
        super().__init__()
        logpy.debug(
            f"constructing {self.__class__.__name__} of depth {depth} w/ "
            f"{in_channels} channels and {n_heads} heads."
        )

        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        if exists(context_dim) and isinstance(context_dim, list):
            if depth != len(context_dim):
                logpy.warning(
                    f"{self.__class__.__name__}: Found context dims "
                    f"{context_dim} of depth {len(context_dim)}, which does not "
                    f"match the specified 'depth' of {depth}. Setting context_dim "
                    f"to {depth * [context_dim[0]]} now."
                )
                # depth does not match context dims.
                assert all(
                    map(lambda x: x == context_dim[0], context_dim)
                ), "need homogeneous context_dim to match depth automatically"
                context_dim = depth * [context_dim[0]]
        elif context_dim is None:
            context_dim = [None] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    attn_mode=attn_type,
                    checkpoint=use_checkpoint,
                    sdp_backend=sdp_backend,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # a single context entry is shared by every block
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
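
# Illustrative sketch (hypothetical helpers, not part of sgm): the two reshapes in
# SpatialTransformer.forward written with plain torch ops instead of einops, plus
# a check of why proj_out is wrapped in zero_module: with a zero-initialised
# output projection, whatever the norm and transformer blocks compute is mapped
# to zero, so the layer returns x + 0 and starts out as an identity map.
import torch
import torch.nn as nn


def to_tokens(x):
    # (b, c, h, w) -> (b, h*w, c); same as rearrange(x, "b c h w -> b (h w) c")
    return x.flatten(2).transpose(1, 2).contiguous()


def from_tokens(t, h, w):
    # (b, h*w, c) -> (b, c, h, w); same as rearrange(t, "b (h w) c -> b c h w", h=h, w=w)
    b, n, c = t.shape
    assert n == h * w, "token count must equal h * w"
    return t.transpose(1, 2).reshape(b, c, h, w)


def zero_init_residual(x, proj_out):
    # mimics the tail of the use_linear=True path: tokens -> proj_out -> image + skip
    h, w = x.shape[-2:]
    return from_tokens(proj_out(to_tokens(x)), h, w) + x
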


class SimpleTransformer(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        context_dim: Optional[int] = None,
        dropout: float = 0.0,
        checkpoint: bool = True,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                BasicTransformerBlock(
                    dim,
                    heads,
                    dim_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    attn_mode="softmax-xformers",
                    checkpoint=checkpoint,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, context)
        return x
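
# Illustrative sketch (hypothetical helper `checkpoint_matches_plain`, not part of
# sgm): gradient checkpointing, as enabled by the `checkpoint` flags above, drops
# the intermediate activations of the wrapped call and recomputes them during
# backward, so outputs and input gradients should match an ordinary call while
# peak memory drops. Assumes PyTorch >= 1.11 (for `use_reentrant`) and a
# deterministic module (no dropout), otherwise the recomputation would differ.
# Note that the reentrant variant does not accept keyword arguments for the
# wrapped function, so extra arguments have to be passed positionally.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def checkpoint_matches_plain(module, x):
    x_plain = x.detach().clone().requires_grad_(True)
    x_ckpt = x.detach().clone().requires_grad_(True)
    y_plain = module(x_plain)
    y_ckpt = checkpoint(module, x_ckpt, use_reentrant=False)
    y_plain.sum().backward()
    y_ckpt.sum().backward()
    return torch.allclose(y_plain, y_ckpt) and torch.allclose(
        x_plain.grad, x_ckpt.grad
    )
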


================================================
FILE: sgm/modules/autoencoding/__init__.py
================================================


================================================
FILE: sgm/modules/autoencoding/losses/__init__.py
================================================
__all__ = [
    "GeneralLPIPSWithDiscriminator",
    "LatentLPIPS",
]

from .discriminator_loss import GeneralLPIPSWithDiscriminator
from .lpips import LatentLPIPS


================================================
FILE: sgm/modules/autoencoding/losses/discriminator_loss.py
================================================
from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
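
# Illustrative sketch (hypothetical names, not the sgm code): the conventional GAN
# discriminator losses that the imported hinge_d_loss / vanilla_d_loss are
# expected to compute (their definitions live in lpips/vqperceptual.py, which is
# not shown here). Both take raw discriminator logits for real and fake samples
# and average the real and fake terms.
import torch
import torch.nn.functional as F


def hinge_d_loss_sketch(logits_real, logits_fake):
    # penalise real logits below +1 and fake logits above -1
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)


def vanilla_d_loss_sketch(logits_real, logits_fake):
    # binary cross-entropy on logits: softplus(-x) = -log(sigmoid(x))
    return 0.5 * (
        torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
    )
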


class GeneralLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = "hinge",
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss "
                f"calculation, the LPIPS loss will be applied to each frame "
                f"independently."
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(
            torch.full((), logvar_init), requires_grad=learn_logvar
        )
        self.learn_logvar = learn_logvar

        discriminator_config = default(
            discriminator_config,
            {
                "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": F
SYMBOL INDEX (668 symbols across 42 files)

FILE: ctrl_model/diffusion_ctrl.py
  class DiffusionEngineCtrl (line 23) | class DiffusionEngineCtrl(DiffusionEngine):
    method __init__ (line 24) | def __init__(
  class CtrlNetWrapper (line 90) | class CtrlNetWrapper(OpenAIWrapper):
    method __init__ (line 91) | def __init__(self, diffusion_model, compile_model: bool = False, ctrln...
    method forward (line 95) | def forward(

FILE: ctrl_model/svd_ctrl.py
  class ControledVideoUnet (line 20) | class ControledVideoUnet(VideoUNet):
    method forward (line 24) | def forward(
  class ControlNetConditioningEmbedding (line 92) | class ControlNetConditioningEmbedding(nn.Module):
    method __init__ (line 93) | def __init__(
    method forward (line 115) | def forward(self, conditioning):
  class MaskEmbedding (line 128) | class MaskEmbedding(nn.Module):
    method __init__ (line 129) | def __init__(
    method forward (line 151) | def forward(self, conditioning):
  class WeightEmbedding (line 164) | class WeightEmbedding(nn.Module):
    method __init__ (line 165) | def __init__(
    method forward (line 189) | def forward(self, conditioning, t, cond_embeddings, mask_embeddings):
  class VideoCtrlNet (line 208) | class VideoCtrlNet(nn.Module):
    method __init__ (line 209) | def __init__(
    method make_zero_conv (line 531) | def make_zero_conv(self, channels):
    method init_from_ckpt (line 534) | def init_from_ckpt(self, ckpt_path=None):
    method init_from_unet (line 559) | def init_from_unet(self, unet):
    method forward (line 564) | def forward(

FILE: main/inference/sample_constant_motion.py
  function sample (line 32) | def sample(
  function get_parser (line 275) | def get_parser():

FILE: main/inference/sample_multi_region.py
  function sample (line 29) | def sample(
  function get_parser (line 289) | def get_parser():

FILE: main/inference/sample_single_region.py
  function sample (line 31) | def sample(
  function get_parser (line 285) | def get_parser():

FILE: sgm/inference/api.py
  class ModelArchitecture (line 19) | class ModelArchitecture(str, Enum):
  class Sampler (line 28) | class Sampler(str, Enum):
  class Discretization (line 37) | class Discretization(str, Enum):
  class Guider (line 42) | class Guider(str, Enum):
  class Thresholder (line 47) | class Thresholder(str, Enum):
  class SamplingParams (line 52) | class SamplingParams:
  class SamplingSpec (line 80) | class SamplingSpec:
  class SamplingPipeline (line 155) | class SamplingPipeline:
    method __init__ (line 156) | def __init__(
    method _load_model (line 173) | def _load_model(self, device="cuda", use_fp16=True):
    method text_to_image (line 184) | def text_to_image(
    method image_to_image (line 212) | def image_to_image(
    method refiner (line 245) | def refiner(
  function get_guider_config (line 280) | def get_guider_config(params: SamplingParams):
  function get_discretization_config (line 306) | def get_discretization_config(params: SamplingParams):
  function get_sampler_config (line 325) | def get_sampler_config(params: SamplingParams):

FILE: sgm/inference/helpers.py
  class WatermarkEmbedder (line 16) | class WatermarkEmbedder:
    method __init__ (line 17) | def __init__(self, watermark):
    method __call__ (line 23) | def __call__(self, image: torch.Tensor) -> torch.Tensor:
  function get_unique_embedder_keys_from_conditioner (line 61) | def get_unique_embedder_keys_from_conditioner(conditioner):
  function perform_save_locally (line 65) | def perform_save_locally(save_path, samples):
  class Img2ImgDiscretizationWrapper (line 77) | class Img2ImgDiscretizationWrapper:
    method __init__ (line 84) | def __init__(self, discretization, strength: float = 1.0):
    method __call__ (line 89) | def __call__(self, *args, **kwargs):
  function do_sample (line 101) | def do_sample(
  function get_batch (line 173) | def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
  function get_input_image_tensor (line 230) | def get_input_image_tensor(image: Image.Image, device="cuda"):
  function do_img2img (line 243) | def do_img2img(

FILE: sgm/lr_scheduler.py
  class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
    method __init__ (line 9) | def __init__(
    method schedule (line 26) | def schedule(self, n, **kwargs):
    method __call__ (line 47) | def __call__(self, n, **kwargs):
  class LambdaWarmUpCosineScheduler2 (line 51) | class LambdaWarmUpCosineScheduler2:
    method __init__ (line 57) | def __init__(
    method find_in_interval (line 76) | def find_in_interval(self, n):
    method schedule (line 83) | def schedule(self, n, **kwargs):
    method __call__ (line 109) | def __call__(self, n, **kwargs):
  class LambdaLinearScheduler (line 113) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    method schedule (line 114) | def schedule(self, n, **kwargs):

FILE: sgm/models/autoencoder.py
  class AbstractAutoencoder (line 22) | class AbstractAutoencoder(pl.LightningModule):
    method __init__ (line 29) | def __init__(
    method apply_ckpt (line 49) | def apply_ckpt(self, ckpt: Union[None, str, dict]):
    method get_input (line 61) | def get_input(self, batch) -> Any:
    method on_train_batch_end (line 64) | def on_train_batch_end(self, *args, **kwargs):
    method ema_scope (line 70) | def ema_scope(self, context=None):
    method encode (line 85) | def encode(self, *args, **kwargs) -> torch.Tensor:
    method decode (line 89) | def decode(self, *args, **kwargs) -> torch.Tensor:
    method instantiate_optimizer_from_config (line 92) | def instantiate_optimizer_from_config(self, params, lr, cfg):
    method configure_optimizers (line 98) | def configure_optimizers(self) -> Any:
  class AutoencodingEngine (line 102) | class AutoencodingEngine(AbstractAutoencoder):
    method __init__ (line 109) | def __init__(
    method get_input (line 170) | def get_input(self, batch: Dict) -> torch.Tensor:
    method get_autoencoder_params (line 176) | def get_autoencoder_params(self) -> list:
    method get_discriminator_params (line 186) | def get_discriminator_params(self) -> list:
    method get_last_layer (line 193) | def get_last_layer(self):
    method encode (line 196) | def encode(
    method decode (line 210) | def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
    method forward (line 214) | def forward(
    method inner_training_step (line 221) | def inner_training_step(
    method training_step (line 281) | def training_step(self, batch: dict, batch_idx: int):
    method validation_step (line 298) | def validation_step(self, batch: dict, batch_idx: int) -> Dict:
    method _validation_step (line 305) | def _validation_step(self, batch: dict, batch_idx: int, postfix: str =...
    method get_param_groups (line 343) | def get_param_groups(
    method configure_optimizers (line 363) | def configure_optimizers(self) -> List[torch.optim.Optimizer]:
    method log_images (line 395) | def log_images(
  class AutoencodingEngineLegacy (line 437) | class AutoencodingEngineLegacy(AutoencodingEngine):
    method __init__ (line 438) | def __init__(self, embed_dim: int, **kwargs):
    method get_autoencoder_params (line 464) | def get_autoencoder_params(self) -> list:
    method encode (line 468) | def encode(
    method decode (line 490) | def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
  class AutoencoderKL (line 508) | class AutoencoderKL(AutoencodingEngineLegacy):
    method __init__ (line 509) | def __init__(self, **kwargs):
  class AutoencoderLegacyVQ (line 523) | class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
    method __init__ (line 524) | def __init__(
  class IdentityFirstStage (line 549) | class IdentityFirstStage(AbstractAutoencoder):
    method __init__ (line 550) | def __init__(self, *args, **kwargs):
    method get_input (line 553) | def get_input(self, x: Any) -> Any:
    method encode (line 556) | def encode(self, x: Any, *args, **kwargs) -> Any:
    method decode (line 559) | def decode(self, x: Any, *args, **kwargs) -> Any:
  class AEIntegerWrapper (line 563) | class AEIntegerWrapper(nn.Module):
    method __init__ (line 564) | def __init__(
    method encode (line 580) | def encode(self, x) -> torch.Tensor:
    method decode (line 589) | def decode(
  class AutoencoderKLModeOnly (line 602) | class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
    method __init__ (line 603) | def __init__(self, **kwargs):

FILE: sgm/models/diffusion.py
  class DiffusionEngine (line 19) | class DiffusionEngine(pl.LightningModule):
    method __init__ (line 20) | def __init__(
    method init_from_ckpt (line 82) | def init_from_ckpt(
    method _init_first_stage (line 104) | def _init_first_stage(self, config):
    method decode_first_stage (line 112) | def decode_first_stage(self, z):
    method decode_first_stage_train (line 132) | def decode_first_stage_train(self, z):
    method encode_first_stage (line 152) | def encode_first_stage(self, x):
    method forward (line 167) | def forward(self, x, batch):
    method get_input (line 173) | def get_input(self, batch):
    method shared_step (line 180) | def shared_step(self, batch: Dict) -> Any:
    method training_step (line 187) | def training_step(self, batch, batch_idx):
    method on_train_start (line 198) | def on_train_start(self, *args, **kwargs):
    method on_train_batch_end (line 202) | def on_train_batch_end(self, *args, **kwargs):
    method ema_scope (line 207) | def ema_scope(self, context=None):
    method instantiate_optimizer_from_config (line 221) | def instantiate_optimizer_from_config(self, params, lr, cfg):
    method configure_optimizers (line 226) | def configure_optimizers(self):
    method sample (line 249) | def sample(
    method log_conditionings (line 266) | def log_conditionings(self, batch: Dict, n: int) -> Dict:
    method log_images (line 305) | def log_images(

FILE: sgm/modules/attention.py
  function exists (line 61) | def exists(val):
  function uniq (line 65) | def uniq(arr):
  function default (line 69) | def default(val, d):
  function max_neg_value (line 75) | def max_neg_value(t):
  function init_ (line 79) | def init_(tensor):
  class GEGLU (line 87) | class GEGLU(nn.Module):
    method __init__ (line 88) | def __init__(self, dim_in, dim_out):
    method forward (line 92) | def forward(self, x):
  class FeedForward (line 97) | class FeedForward(nn.Module):
    method __init__ (line 98) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
    method forward (line 112) | def forward(self, x):
  function zero_module (line 116) | def zero_module(module):
  function Normalize (line 125) | def Normalize(in_channels):
  class LinearAttention (line 131) | class LinearAttention(nn.Module):
    method __init__ (line 132) | def __init__(self, dim, heads=4, dim_head=32):
    method forward (line 139) | def forward(self, x):
  class SelfAttention (line 154) | class SelfAttention(nn.Module):
    method __init__ (line 157) | def __init__(
    method forward (line 179) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class SpatialSelfAttention (line 210) | class SpatialSelfAttention(nn.Module):
    method __init__ (line 211) | def __init__(self, in_channels):
    method forward (line 229) | def forward(self, x):
  class CrossAttention (line 255) | class CrossAttention(nn.Module):
    method __init__ (line 256) | def __init__(
    method forward (line 281) | def forward(
  class MemoryEfficientCrossAttention (line 347) | class MemoryEfficientCrossAttention(nn.Module):
    method __init__ (line 349) | def __init__(
    method forward (line 373) | def forward(
  class BasicTransformerBlock (line 463) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 469) | def __init__(
    method forward (line 534) | def forward(
    method _forward (line 558) | def _forward(
  class BasicTransformerSingleLayerBlock (line 582) | class BasicTransformerSingleLayerBlock(nn.Module):
    method __init__ (line 589) | def __init__(
    method forward (line 615) | def forward(self, x, context=None):
    method _forward (line 620) | def _forward(self, x, context=None):
  class SpatialTransformer (line 626) | class SpatialTransformer(nn.Module):
    method __init__ (line 636) | def __init__(
    method forward (line 709) | def forward(self, x, context=None):
  class SimpleTransformer (line 733) | class SimpleTransformer(nn.Module):
    method __init__ (line 734) | def __init__(
    method forward (line 759) | def forward(

FILE: sgm/modules/autoencoding/losses/discriminator_loss.py
  class GeneralLPIPSWithDiscriminator (line 17) | class GeneralLPIPSWithDiscriminator(nn.Module):
    method __init__ (line 18) | def __init__(
    method get_trainable_parameters (line 85) | def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
    method get_trainable_autoencoder_parameters (line 88) | def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
    method log_images (line 94) | def log_images(
    method calculate_adaptive_weight (line 196) | def calculate_adaptive_weight(
    method forward (line 207) | def forward(
    method get_nll_loss (line 294) | def get_nll_loss(

FILE: sgm/modules/autoencoding/losses/lpips.py
  class LatentLPIPS (line 8) | class LatentLPIPS(nn.Module):
    method __init__ (line 9) | def __init__(
    method init_decoder (line 27) | def init_decoder(self, config):
    method forward (line 32) | def forward(self, latent_inputs, latent_predictions, image_inputs, spl...

FILE: sgm/modules/autoencoding/lpips/loss/lpips.py
  class LPIPS (line 12) | class LPIPS(nn.Module):
    method __init__ (line 14) | def __init__(self, use_dropout=True):
    method load_from_pretrained (line 28) | def load_from_pretrained(self, name="vgg_lpips"):
    method from_pretrained (line 36) | def from_pretrained(cls, name="vgg_lpips"):
    method forward (line 46) | def forward(self, input, target):
  class ScalingLayer (line 70) | class ScalingLayer(nn.Module):
    method __init__ (line 71) | def __init__(self):
    method forward (line 80) | def forward(self, inp):
  class NetLinLayer (line 84) | class NetLinLayer(nn.Module):
    method __init__ (line 87) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
  class vgg16 (line 102) | class vgg16(torch.nn.Module):
    method __init__ (line 103) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 126) | def forward(self, X):
  function normalize_tensor (line 144) | def normalize_tensor(x, eps=1e-10):
  function spatial_average (line 149) | def spatial_average(x, keepdim=True):

FILE: sgm/modules/autoencoding/lpips/model/model.py
  function weights_init (line 8) | def weights_init(m):
  class NLayerDiscriminator (line 17) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 22) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
    method forward (line 86) | def forward(self, input):

FILE: sgm/modules/autoencoding/lpips/util.py
  function download (line 16) | def download(url, local_path, chunk_size=1024):
  function md5_hash (line 28) | def md5_hash(path):
  function get_ckpt_path (line 34) | def get_ckpt_path(name, root, check=False):
  class ActNorm (line 45) | class ActNorm(nn.Module):
    method __init__ (line 46) | def __init__(
    method initialize (line 58) | def initialize(self, input):
    method forward (line 79) | def forward(self, input, reverse=False):
    method reverse (line 107) | def reverse(self, output):

FILE: sgm/modules/autoencoding/lpips/vqperceptual.py
  function hinge_d_loss (line 5) | def hinge_d_loss(logits_real, logits_fake):
  function vanilla_d_loss (line 12) | def vanilla_d_loss(logits_real, logits_fake):

FILE: sgm/modules/autoencoding/regularizers/__init__.py
  class DiagonalGaussianRegularizer (line 13) | class DiagonalGaussianRegularizer(AbstractRegularizer):
    method __init__ (line 14) | def __init__(self, sample: bool = True):
    method get_trainable_parameters (line 18) | def get_trainable_parameters(self) -> Any:
    method forward (line 21) | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:

FILE: sgm/modules/autoencoding/regularizers/base.py
  class AbstractRegularizer (line 9) | class AbstractRegularizer(nn.Module):
    method __init__ (line 10) | def __init__(self):
    method forward (line 13) | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    method get_trainable_parameters (line 17) | def get_trainable_parameters(self) -> Any:
  class IdentityRegularizer (line 21) | class IdentityRegularizer(AbstractRegularizer):
    method forward (line 22) | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    method get_trainable_parameters (line 25) | def get_trainable_parameters(self) -> Any:
  function measure_perplexity (line 29) | def measure_perplexity(

FILE: sgm/modules/autoencoding/regularizers/quantize.py
  class AbstractQuantizer (line 17) | class AbstractQuantizer(AbstractRegularizer):
    method __init__ (line 18) | def __init__(self):
    method remap_to_used (line 26) | def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
    method unmap_to_all (line 43) | def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
    method get_codebook_entry (line 55) | def get_codebook_entry(
    method get_trainable_parameters (line 60) | def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
  class GumbelQuantizer (line 64) | class GumbelQuantizer(AbstractQuantizer):
    method __init__ (line 73) | def __init__(
    method forward (line 119) | def forward(
    method get_codebook_entry (line 158) | def get_codebook_entry(self, indices, shape):
  class VectorQuantizer (line 172) | class VectorQuantizer(AbstractQuantizer):
    method __init__ (line 184) | def __init__(
    method forward (line 234) | def forward(
    method get_codebook_entry (line 302) | def get_codebook_entry(
  class EmbeddingEMA (line 323) | class EmbeddingEMA(nn.Module):
    method __init__ (line 324) | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
    method forward (line 334) | def forward(self, embed_id):
    method cluster_size_ema_update (line 337) | def cluster_size_ema_update(self, new_cluster_size):
    method embed_avg_ema_update (line 342) | def embed_avg_ema_update(self, new_embed_avg):
    method weight_update (line 345) | def weight_update(self, num_tokens):
  class EMAVectorQuantizer (line 355) | class EMAVectorQuantizer(AbstractQuantizer):
    method __init__ (line 356) | def __init__(
    method forward (line 396) | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
  class VectorQuantizerWithInputProjection (line 446) | class VectorQuantizerWithInputProjection(VectorQuantizer):
    method __init__ (line 447) | def __init__(
    method forward (line 464) | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
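All of the quantizers listed above rely on a nearest-neighbour lookup into a codebook. The following is a minimal scalar-list sketch of that lookup. It expands the squared distance as ||z||² + ||e||² − 2⟨z, e⟩, which is the form batched implementations commonly compute with a single matrix product. `nearest_code` is a hypothetical name and is not a function from this repository.

```python
from typing import List, Sequence, Tuple


def nearest_code(z: Sequence[float], codebook: Sequence[Sequence[float]]) -> Tuple[int, List[float]]:
    """Return (index, vector) of the codebook entry closest to z in L2 distance."""
    if not codebook:
        raise ValueError("empty codebook")
    z_sq = sum(v * v for v in z)
    best_idx, best_dist = -1, float("inf")
    for idx, e in enumerate(codebook):
        if len(e) != len(z):
            raise ValueError("dimension mismatch")
        # ||z - e||^2 expanded so the ||z||^2 term is computed once
        dist = z_sq + sum(v * v for v in e) - 2.0 * sum(a * b for a, b in zip(z, e))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx, list(codebook[best_idx])
```

During training, the quantized vector is typically passed back with a straight-through estimator (`z + (z_q - z).detach()` in PyTorch). That step needs autograd, so this sketch omits it.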

FILE: sgm/modules/autoencoding/temporal_ae.py
  class VideoResBlock (line 18) | class VideoResBlock(ResnetBlock):
    method __init__ (line 19) | def __init__(
    method get_alpha (line 56) | def get_alpha(self, bs):
    method forward (line 64) | def forward(self, x, temb, skip_video=False, timesteps=None):
  class AE3DConv (line 86) | class AE3DConv(torch.nn.Conv2d):
    method __init__ (line 87) | def __init__(self, in_channels, out_channels, video_kernel_size=3, *ar...
    method forward (line 101) | def forward(self, input, timesteps, skip_video=False):
  class VideoBlock (line 110) | class VideoBlock(AttnBlock):
    method __init__ (line 111) | def __init__(
    method forward (line 142) | def forward(self, x, timesteps, skip_video=False):
    method get_alpha (line 169) | def get_alpha(
  class MemoryEfficientVideoBlock (line 180) | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
    method __init__ (line 181) | def __init__(
    method forward (line 212) | def forward(self, x, timesteps, skip_time_block=False):
    method get_alpha (line 239) | def get_alpha(
  function make_time_attn (line 250) | def make_time_attn(
  class Conv2DWrapper (line 288) | class Conv2DWrapper(torch.nn.Conv2d):
    method forward (line 289) | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
  class VideoDecoder (line 293) | class VideoDecoder(Decoder):
    method __init__ (line 296) | def __init__(
    method get_last_layer (line 314) | def get_last_layer(self, skip_time_mix=False, **kwargs):
    method _make_attn (line 324) | def _make_attn(self) -> Callable:
    method _make_conv (line 334) | def _make_conv(self) -> Callable:
    method _make_resblock (line 340) | def _make_resblock(self) -> Callable:

FILE: sgm/modules/diffusionmodules/denoiser.py
  class Denoiser (line 11) | class Denoiser(nn.Module):
    method __init__ (line 12) | def __init__(self, scaling_config: Dict):
    method possibly_quantize_sigma (line 17) | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
    method possibly_quantize_c_noise (line 20) | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Te...
    method forward (line 23) | def forward(
  class DiscreteDenoiser (line 44) | class DiscreteDenoiser(Denoiser):
    method __init__ (line 45) | def __init__(
    method sigma_to_idx (line 63) | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
    method idx_to_sigma (line 67) | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
    method possibly_quantize_sigma (line 70) | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
    method possibly_quantize_c_noise (line 73) | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Te...
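`Denoiser.forward` combines the network with the coefficients returned by a scaling object. `DiscreteDenoiser` additionally snaps σ onto a fixed schedule through `sigma_to_idx` / `idx_to_sigma`. The sketch below illustrates both ideas on scalars. It assumes the common preconditioning form D(x; σ) = c_skip·x + c_out·F(c_in·x, c_noise) and a scaling that returns `(c_skip, c_out, c_in, c_noise)`. The function names are hypothetical.

```python
from typing import Callable, Sequence, Tuple

# A scaling maps sigma to (c_skip, c_out, c_in, c_noise).
Scaling = Callable[[float], Tuple[float, float, float, float]]


def quantize_sigma(sigma: float, sigmas: Sequence[float]) -> float:
    """Snap sigma to the nearest entry of a discrete schedule (argmin |sigma - s_i|)."""
    if not sigmas:
        raise ValueError("empty schedule")
    idx = min(range(len(sigmas)), key=lambda i: abs(sigma - sigmas[i]))
    return sigmas[idx]


def precond_denoise(
    x: float,
    sigma: float,
    network: Callable[[float, float], float],
    scaling: Scaling,
) -> float:
    """Preconditioned denoiser: c_skip * x + c_out * F(c_in * x, c_noise)."""
    c_skip, c_out, c_in, c_noise = scaling(sigma)
    return c_skip * x + c_out * network(c_in * x, c_noise)
```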

FILE: sgm/modules/diffusionmodules/denoiser_scaling.py
  class DenoiserScaling (line 7) | class DenoiserScaling(ABC):
    method __call__ (line 9) | def __call__(
  class EDMScaling (line 15) | class EDMScaling:
    method __init__ (line 16) | def __init__(self, sigma_data: float = 0.5):
    method __call__ (line 19) | def __call__(
  class EpsScaling (line 29) | class EpsScaling:
    method __call__ (line 30) | def __call__(
  class VScaling (line 40) | class VScaling:
    method __call__ (line 41) | def __call__(
  class VScalingWithEDMcNoise (line 51) | class VScalingWithEDMcNoise(DenoiserScaling):
    method __call__ (line 52) | def __call__(
  class VScalingWithEDMcNoise_fp16 (line 62) | class VScalingWithEDMcNoise_fp16(DenoiserScaling):
    method __call__ (line 63) | def __call__(
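`EDMScaling` takes `sigma_data=0.5` as its default. As a reference point, the sketch below computes the EDM preconditioning coefficients from Karras et al. (2022) in the `(c_skip, c_out, c_in, c_noise)` order that the denoiser consumes. It is a scalar sketch of the published formulas. The V-prediction variants listed here (`VScaling`, `VScalingWithEDMcNoise`) use different coefficients and are not reproduced.

```python
import math
from typing import Tuple


def edm_scaling(sigma: float, sigma_data: float = 0.5) -> Tuple[float, float, float, float]:
    """EDM preconditioning coefficients for noise level sigma (scalar version)."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom              # how much of the noisy input to keep
    c_out = sigma * sigma_data / math.sqrt(denom)  # scale of the network's output
    c_in = 1.0 / math.sqrt(denom)                  # normalizes input to roughly unit variance
    c_noise = 0.25 * math.log(sigma)               # noise conditioning fed to the network
    return c_skip, c_out, c_in, c_noise
```

As σ → 0, c_skip → 1 and c_out → 0, so the denoiser returns its input almost unchanged. At large σ, the network output dominates.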

FILE: sgm/modules/diffusionmodules/denoiser_weighting.py
  class UnitWeighting (line 4) | class UnitWeighting:
    method __call__ (line 5) | def __call__(self, sigma):
  class EDMWeighting (line 9) | class EDMWeighting:
    method __init__ (line 10) | def __init__(self, sigma_data=0.5):
    method __call__ (line 13) | def __call__(self, sigma):
  class VWeighting (line 17) | class VWeighting(EDMWeighting):
    method __init__ (line 18) | def __init__(self):
  class EpsWeighting (line 22) | class EpsWeighting:
    method __call__ (line 23) | def __call__(self, sigma):
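`EDMWeighting(sigma_data=0.5)` and `EpsWeighting` assign a loss weight to each noise level. The sketch below gives the EDM weight λ(σ) = (σ² + σ_data²) / (σ·σ_data)², which equals 1 / c_out², together with the σ⁻² weight that turns a denoised-space loss into an ε-space loss. These are the standard formulas, written as plain functions with hypothetical names.

```python
def edm_weight(sigma: float, sigma_data: float = 0.5) -> float:
    """EDM loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2


def eps_weight(sigma: float) -> float:
    """Weight sigma^-2, which rescales an x0-space MSE to an epsilon-space MSE."""
    return sigma ** -2.0
```

Because λ(σ)·c_out(σ)² = 1, the EDM-weighted loss has unit weight when measured in the raw network's output space.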

FILE: sgm/modules/diffusionmodules/discretizer.py
  function generate_roughly_equally_spaced_steps (line 11) | def generate_roughly_equally_spaced_steps(
  class Discretization (line 17) | class Discretization:
    method __call__ (line 18) | def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
    method get_sigmas (line 24) | def get_sigmas(self, n, device):
  class EDMDiscretization (line 28) | class EDMDiscretization(Discretization):
    method __init__ (line 29) | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    method get_sigmas (line 34) | def get_sigmas(self, n, device="cpu"):
  class LegacyDDPMDiscretization (line 42) | class LegacyDDPMDiscretization(Discretization):
    method __init__ (line 43) | def __init__(
    method get_sigmas (line 58) | def get_sigmas(self, n, device="cpu"):
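`EDMDiscretization(sigma_min=0.002, sigma_max=80.0, rho=7.0)` matches the defaults of the Karras et al. noise schedule, and `Discretization.__call__` can append a trailing zero. The sketch below computes that schedule in plain Python under those defaults. It interpolates linearly in σ^(1/ρ) space from σ_max down to σ_min. The function name is hypothetical.

```python
from typing import List


def edm_sigmas(
    n: int,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    rho: float = 7.0,
    append_zero: bool = True,
) -> List[float]:
    """Karras-style noise levels, descending from sigma_max to sigma_min.

    sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    """
    if n < 2:
        raise ValueError("n must be >= 2")
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    sigmas = [(max_inv + i / (n - 1) * (min_inv - max_inv)) ** rho for i in range(n)]
    if append_zero:
        sigmas.append(0.0)  # the final step denoises all the way to sigma = 0
    return sigmas
```

A larger ρ concentrates more steps at low noise levels, where fine detail is resolved.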

FILE: sgm/modules/diffusionmodules/guiders.py
  class Guider (line 13) | class Guider(ABC):
    method __call__ (line 15) | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
    method prepare_inputs (line 18) | def prepare_inputs(
  class VanillaCFG (line 24) | class VanillaCFG(Guider):
    method __init__ (line 25) | def __init__(self, scale: float):
    method __call__ (line 28) | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    method prepare_inputs (line 33) | def prepare_inputs(self, x, s, c, uc):
  class IdentityGuider (line 45) | class IdentityGuider(Guider):
    method __call__ (line 46) | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
    method prepare_inputs (line 49) | def prepare_inputs(
  class LinearPredictionGuider (line 60) | class LinearPredictionGuider(Guider):
    method __init__ (line 61) | def __init__(
    method __call__ (line 78) | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    method prepare_inputs (line 89) | def prepare_inputs(
  class LinearPredictionGuider_fp16 (line 102) | class LinearPredictionGuider_fp16(Guider):
    method __init__ (line 103) | def __init__(
    method __call__ (line 120) | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    method prepare_inputs (line 130) | def prepare_inputs(
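`VanillaCFG` applies classifier-free guidance with a single `scale`. The listed `__init__` of `LinearPredictionGuider` is truncated, so the following sketch makes an assumption that is common for image-to-video guiders: the guidance scale rises linearly from a minimum to a maximum across frames. Each frame's prediction is x_u + s·(x_c − x_u). Values are scalars per frame, and both function names are hypothetical.

```python
from typing import List, Sequence


def linear_frame_scales(min_scale: float, max_scale: float, num_frames: int) -> List[float]:
    """Per-frame guidance scales, evenly spaced from min_scale to max_scale."""
    if num_frames < 1:
        raise ValueError("num_frames must be >= 1")
    if num_frames == 1:
        return [min_scale]  # same convention as numpy.linspace with a single point
    step = (max_scale - min_scale) / (num_frames - 1)
    return [min_scale + i * step for i in range(num_frames)]


def cfg_combine(
    x_uncond: Sequence[float], x_cond: Sequence[float], scales: Sequence[float]
) -> List[float]:
    """Classifier-free guidance per frame: x_u + s * (x_c - x_u)."""
    if not (len(x_uncond) == len(x_cond) == len(scales)):
        raise ValueError("length mismatch")
    return [u + s * (c - u) for u, c, s in zip(x_uncond, x_cond, scales)]
```

A scale of 1.0 reproduces the conditional prediction, and values above 1.0 push each frame further away from the unconditional one.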

FILE: sgm/modules/diffusionmodules/model.py
  function get_timestep_embedding (line 26) | def get_timestep_embedding(timesteps, embedding_dim):
  function nonlinearity (line 47) | def nonlinearity(x):
  function Normalize (line 52) | def Normalize(in_channels, num_groups=32):
  class Upsample (line 58) | class Upsample(nn.Module):
    method __init__ (line 59) | def __init__(self, in_channels, with_conv):
    method forward (line 67) | def forward(self, x):
  class Downsample (line 74) | class Downsample(nn.Module):
    method __init__ (line 75) | def __init__(self, in_channels, with_conv):
    method forward (line 84) | def forward(self, x):
  class ResnetBlock (line 94) | class ResnetBlock(nn.Module):
    method __init__ (line 95) | def __init__(
    method forward (line 131) | def forward(self, x, temb):
  class LinAttnBlock (line 154) | class LinAttnBlock(LinearAttention):
    method __init__ (line 157) | def __init__(self, in_channels):
  class AttnBlock (line 161) | class AttnBlock(nn.Module):
    method __init__ (line 162) | def __init__(self, in_channels):
    method attention (line 180) | def attention(self, h_: torch.Tensor) -> torch.Tensor:
    method forward (line 197) | def forward(self, x, **kwargs):
  class MemoryEfficientAttnBlock (line 204) | class MemoryEfficientAttnBlock(nn.Module):
    method __init__ (line 212) | def __init__(self, in_channels):
    method attention (line 231) | def attention(self, h_: torch.Tensor) -> torch.Tensor:
    method forward (line 261) | def forward(self, x, **kwargs):
  class MemoryEfficientCrossAttentionWrapper (line 268) | class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    method forward (line 269) | def forward(self, x, context=None, mask=None, **unused_kwargs):
  function make_attn (line 277) | def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
  class Model (line 312) | class Model(nn.Module):
    method __init__ (line 313) | def __init__(
    method forward (line 434) | def forward(self, x, t=None, context=None):
    method get_last_layer (line 483) | def get_last_layer(self):
  class Encoder (line 487) | class Encoder(nn.Module):
    method __init__ (line 488) | def __init__(
    method forward (line 576) | def forward(self, x):
  class Decoder (line 604) | class Decoder(nn.Module):
    method __init__ (line 605) | def __init__(
    method _make_attn (line 703) | def _make_attn(self) -> Callable:
    method _make_resblock (line 706) | def _make_resblock(self) -> Callable:
    method _make_conv (line 709) | def _make_conv(self) -> Callable:
    method get_last_layer (line 712) | def get_last_layer(self, **kwargs):
    method forward (line 715) | def forward(self, z, **kwargs):

FILE: sgm/modules/diffusionmodules/openaimodel.py
  class AttentionPool2d (line 22) | class AttentionPool2d(nn.Module):
    method __init__ (line 27) | def __init__(
    method forward (line 43) | def forward(self, x: th.Tensor) -> th.Tensor:
  class TimestepBlock (line 54) | class TimestepBlock(nn.Module):
    method forward (line 60) | def forward(self, x: th.Tensor, emb: th.Tensor):
  class TimestepEmbedSequential (line 66) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    method forward (line 72) | def forward(
  class Upsample (line 107) | class Upsample(nn.Module):
    method __init__ (line 116) | def __init__(
    method forward (line 139) | def forward(self, x: th.Tensor) -> th.Tensor:
  class Downsample (line 160) | class Downsample(nn.Module):
    method __init__ (line 169) | def __init__(
    method forward (line 204) | def forward(self, x: th.Tensor) -> th.Tensor:
  class ResBlock (line 210) | class ResBlock(TimestepBlock):
    method __init__ (line 226) | def __init__(
    method forward (line 316) | def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
    method _forward (line 328) | def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
  class AttentionBlock (line 357) | class AttentionBlock(nn.Module):
    method __init__ (line 364) | def __init__(
    method forward (line 393) | def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:
    method _forward (line 396) | def _forward(self, x: th.Tensor) -> th.Tensor:
  class QKVAttentionLegacy (line 405) | class QKVAttentionLegacy(nn.Module):
    method __init__ (line 410) | def __init__(self, n_heads: int):
    method forward (line 414) | def forward(self, qkv: th.Tensor) -> th.Tensor:
  class QKVAttention (line 433) | class QKVAttention(nn.Module):
    method __init__ (line 438) | def __init__(self, n_heads: int):
    method forward (line 442) | def forward(self, qkv: th.Tensor) -> th.Tensor:
  class Timestep (line 463) | class Timestep(nn.Module):
    method __init__ (line 464) | def __init__(self, dim: int):
    method forward (line 468) | def forward(self, t: th.Tensor) -> th.Tensor:
  class UNetModel (line 472) | class UNetModel(nn.Module):
    method __init__ (line 502) | def __init__(
    method forward (line 816) | def forward(

FILE: sgm/modules/diffusionmodules/sampling.py
  class BaseDiffusionSampler (line 21) | class BaseDiffusionSampler:
    method __init__ (line 22) | def __init__(
    method prepare_sampling_loop (line 41) | def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
    method denoise (line 54) | def denoise(self, x, denoiser, sigma, cond, uc):
    method get_sigma_gen (line 60) | def get_sigma_gen(self, num_sigmas):
  class SingleStepDiffusionSampler (line 75) | class SingleStepDiffusionSampler(BaseDiffusionSampler):
    method sampler_step (line 76) | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args...
    method euler_step (line 79) | def euler_step(self, x, d, dt):
  class EDMSampler (line 83) | class EDMSampler(SingleStepDiffusionSampler):
    method __init__ (line 84) | def __init__(
    method sampler_step (line 94) | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, ...
    method __call__ (line 112) | def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
  class AncestralSampler (line 136) | class AncestralSampler(SingleStepDiffusionSampler):
    method __init__ (line 137) | def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
    method ancestral_euler_step (line 144) | def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
    method ancestral_step (line 150) | def ancestral_step(self, x, sigma, next_sigma, sigma_up):
    method __call__ (line 158) | def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
  class LinearMultistepSampler (line 176) | class LinearMultistepSampler(BaseDiffusionSampler):
    method __init__ (line 177) | def __init__(
    method __call__ (line 187) | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
  class EulerEDMSampler (line 214) | class EulerEDMSampler(EDMSampler):
    method possible_correction_step (line 215) | def possible_correction_step(
  class HeunEDMSampler (line 221) | class HeunEDMSampler(EDMSampler):
    method possible_correction_step (line 222) | def possible_correction_step(
  class EulerAncestralSampler (line 240) | class EulerAncestralSampler(AncestralSampler):
    method sampler_step (line 241) | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
  class DPMPP2SAncestralSampler (line 250) | class DPMPP2SAncestralSampler(AncestralSampler):
    method get_variables (line 251) | def get_variables(self, sigma, sigma_down):
    method get_mult (line 257) | def get_mult(self, h, s, t, t_next):
    method sampler_step (line 265) | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, ...
  class DPMPP2MSampler (line 290) | class DPMPP2MSampler(BaseDiffusionSampler):
    method get_variables (line 291) | def get_variables(self, sigma, next_sigma, previous_sigma=None):
    method get_mult (line 302) | def get_mult(self, h, r, t, t_next, previous_sigma):
    method sampler_step (line 313) | def sampler_step(
    method __call__ (line 347) | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
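`EDMSampler.sampler_step` and `euler_step` implement Euler steps of the probability-flow ODE, optionally with stochastic "churn". The following is a minimal scalar sketch of one such step under the standard EDM formulation. It raises the noise level to σ̂ = σ(1 + γ), evaluates the denoiser there, and moves along d = (x − D(x)) / σ̂ to the next σ. The function name and the toy denoiser in the test are hypothetical.

```python
import math
import random
from typing import Callable, Optional


def euler_edm_step(
    x: float,
    sigma: float,
    next_sigma: float,
    denoiser: Callable[[float, float], float],
    gamma: float = 0.0,
    s_noise: float = 1.0,
    rng: Optional[random.Random] = None,
) -> float:
    """One Euler step of an EDM-style sampler on a scalar state."""
    sigma_hat = sigma * (gamma + 1.0)
    if gamma > 0:
        # churn: add fresh noise so the state sits at noise level sigma_hat
        rng = rng or random.Random()
        eps = rng.gauss(0.0, 1.0) * s_noise
        x = x + eps * math.sqrt(sigma_hat ** 2 - sigma ** 2)
    denoised = denoiser(x, sigma_hat)
    d = (x - denoised) / sigma_hat  # ODE derivative dx/dsigma
    return x + (next_sigma - sigma_hat) * d  # Euler update with dt = next_sigma - sigma_hat
```

A Heun variant, which is what `HeunEDMSampler`'s `possible_correction_step` suggests, would evaluate the denoiser again at `next_sigma` (when it is non-zero) and average the two derivatives.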

FILE: sgm/modules/diffusionmodules/sampling_utils.py
  function linear_multistep_coeff (line 7) | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
  function get_ancestral_step (line 22) | def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
  function to_d (line 34) | def to_d(x, sigma, denoised):
  function to_neg_log_sigma (line 38) | def to_neg_log_sigma(sigma):
  function to_sigma (line 42) | def to_sigma(neg_log_sigma):
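`get_ancestral_step(sigma_from, sigma_to, eta=1.0)` splits one sampling step into a deterministic move down to σ_down plus fresh noise of size σ_up, chosen so that σ_down² + σ_up² = σ_to². The sketch below is a scalar rendering of the widely used k-diffusion formula. Treating the repository's version as identical is an assumption.

```python
import math
from typing import Tuple


def get_ancestral_step(sigma_from: float, sigma_to: float, eta: float = 1.0) -> Tuple[float, float]:
    """Return (sigma_down, sigma_up) for an ancestral step sigma_from -> sigma_to."""
    if not eta:
        return sigma_to, 0.0  # eta = 0: purely deterministic step
    sigma_up = min(
        sigma_to,
        eta * math.sqrt(sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2),
    )
    sigma_down = math.sqrt(sigma_to ** 2 - sigma_up ** 2)
    return sigma_down, sigma_up
```

In an ancestral Euler step, the state first moves with `x + (sigma_down - sigma) * d` and then receives `noise * sigma_up`. This lands it at noise level σ_to.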

FILE: sgm/modules/diffusionmodules/sigma_sampling.py
  class EDMSampling (line 6) | class EDMSampling:
    method __init__ (line 7) | def __init__(self, p_mean=-1.2, p_std=1.2):
    method __call__ (line 11) | def __call__(self, n_samples, rand=None):
  class DiscreteSampling (line 16) | class DiscreteSampling:
    method __init__ (line 17) | def __init__(self, discretization_config, num_idx, do_append_zero=Fals...
    method idx_to_sigma (line 23) | def idx_to_sigma(self, idx):
    method __call__ (line 26) | def __call__(self, n_samples, rand=None):
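`EDMSampling(p_mean=-1.2, p_std=1.2)` draws training noise levels. In the EDM recipe, log σ is normally distributed with that mean and standard deviation. `DiscreteSampling` instead picks from a fixed schedule, and its full signature is truncated above. The sketch below illustrates both ideas, assuming uniform index sampling for the discrete case. The function names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_edm_sigmas(
    n_samples: int, p_mean: float = -1.2, p_std: float = 1.2, rng: Optional[random.Random] = None
) -> List[float]:
    """Training noise levels with log(sigma) ~ Normal(p_mean, p_std^2)."""
    rng = rng or random.Random()
    return [math.exp(p_mean + p_std * rng.gauss(0.0, 1.0)) for _ in range(n_samples)]


def sample_discrete_sigmas(
    n_samples: int, sigmas: Sequence[float], rng: Optional[random.Random] = None
) -> List[float]:
    """Pick noise levels uniformly at random from a discrete schedule."""
    rng = rng or random.Random()
    return [sigmas[rng.randrange(len(sigmas))] for _ in range(n_samples)]
```

With these defaults, the median training σ is exp(−1.2) ≈ 0.30. Training therefore concentrates on moderate noise levels.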

FILE: sgm/modules/diffusionmodules/util.py
  function make_beta_schedule (line 20) | def make_beta_schedule(
  function extract_into_tensor (line 36) | def extract_into_tensor(a, t, x_shape):
  function mixed_checkpoint (line 42) | def mixed_checkpoint(func, inputs: dict, params, flag):
  class MixedCheckpointFunction (line 78) | class MixedCheckpointFunction(torch.autograd.Function):
    method forward (line 80) | def forward(
    method backward (line 120) | def backward(ctx, *output_grads):
  function checkpoint (line 155) | def checkpoint(func, inputs, params, flag):
  class CheckpointFunction (line 173) | class CheckpointFunction(torch.autograd.Function):
    method forward (line 175) | def forward(ctx, run_function, length, *args):
    method backward (line 189) | def backward(ctx, *output_grads):
  function timestep_embedding (line 209) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  function zero_module (line 236) | def zero_module(module):
  function scale_module (line 245) | def scale_module(module, scale):
  function mean_flat (line 254) | def mean_flat(tensor):
  function normalization (line 261) | def normalization(channels):
  class SiLU (line 271) | class SiLU(nn.Module):
    method forward (line 272) | def forward(self, x):
  class GroupNorm32 (line 275) | class GroupNorm32(nn.GroupNorm):
    method forward (line 276) | def forward(self, x):
  function conv_nd (line 279) | def conv_nd(dims, *args, **kwargs):
  function linear (line 292) | def linear(*args, **kwargs):
  function avg_pool_nd (line 299) | def avg_pool_nd(dims, *args, **kwargs):
  class AlphaBlender (line 312) | class AlphaBlender(nn.Module):
    method __init__ (line 315) | def __init__(
    method get_alpha (line 341) | def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
    method forward (line 358) | def forward(
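`timestep_embedding(timesteps, dim, max_period=10000, ...)` turns a scalar timestep or noise level into a vector. The sketch below follows the common guided-diffusion layout: cosines first, then sines, with frequencies max_period^(−k/h) for h = dim // 2, and a zero pad when dim is odd. This layout is an assumption about the repository's version, and the `repeat_only` branch is not shown.

```python
import math
from typing import List


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Sinusoidal embedding of a single scalar timestep t into dim values."""
    half = dim // 2
    # geometric frequency ladder from 1 down to roughly 1 / max_period
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # pad odd dimensions with a single zero
    return emb
```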

FILE: sgm/modules/diffusionmodules/video_model.py
  class VideoResBlock (line 13) | class VideoResBlock(ResBlock):
    method __init__ (line 14) | def __init__(
    method forward (line 63) | def forward(
  class VideoUNet (line 85) | class VideoUNet(nn.Module):
    method __init__ (line 86) | def __init__(
    method forward (line 444) | def forward(

FILE: sgm/modules/diffusionmodules/wrappers.py
  class IdentityWrapper (line 8) | class IdentityWrapper(nn.Module):
    method __init__ (line 9) | def __init__(self, diffusion_model, compile_model: bool = False):
    method forward (line 19) | def forward(self, *args, **kwargs):
  class OpenAIWrapper (line 23) | class OpenAIWrapper(IdentityWrapper):
    method forward (line 24) | def forward(

FILE: sgm/modules/distributions/distributions.py
  class AbstractDistribution (line 5) | class AbstractDistribution:
    method sample (line 6) | def sample(self):
    method mode (line 9) | def mode(self):
  class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
    method __init__ (line 14) | def __init__(self, value):
    method sample (line 17) | def sample(self):
    method mode (line 20) | def mode(self):
  class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
    method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
    method sample (line 37) | def sample(self):
    method kl (line 43) | def kl(self, other=None):
    method nll (line 62) | def nll(self, sample, dims=[1, 2, 3]):
    method mode (line 71) | def mode(self):
  function normal_kl (line 75) | def normal_kl(mean1, logvar1, mean2, logvar2):
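`normal_kl(mean1, logvar1, mean2, logvar2)` and `DiagonalGaussianDistribution.kl` compute KL divergences between Gaussians parameterized by log-variance. The sketch below gives the closed form for scalars, plus the sum over dimensions against N(0, 1) that a VAE-style latent regularizer uses. The helper names other than `normal_kl` are hypothetical.

```python
import math
from typing import Sequence


def normal_kl(mean1: float, logvar1: float, mean2: float, logvar2: float) -> float:
    """KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ) for scalars."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


def diag_gaussian_kl_to_standard(means: Sequence[float], logvars: Sequence[float]) -> float:
    """Sum over dimensions of KL( N(m, exp(lv)) || N(0, 1) )."""
    return sum(0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(means, logvars))
```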

FILE: sgm/modules/ema.py
  class LitEma (line 5) | class LitEma(nn.Module):
    method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
    method reset_num_updates (line 29) | def reset_num_updates(self):
    method forward (line 33) | def forward(self, model):
    method copy_to (line 56) | def copy_to(self, model):
    method store (line 65) | def store(self, parameters):
    method restore (line 74) | def restore(self, parameters):
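`LitEma(model, decay=0.9999, use_num_upates=True)` keeps a shadow copy of the weights. Note that `use_num_upates` is spelled that way in the listed signature. The sketch below shows the usual update rule for this kind of EMA on named scalar parameters: shadow ← shadow − (1 − decay)·(shadow − param). When update counting is enabled, the decay is capped by (1 + n)/(10 + n) so that the average warms up quickly. The class name and its API are hypothetical.

```python
from typing import Dict


class SimpleEma:
    """Hypothetical EMA over named scalar parameters."""

    def __init__(self, params: Dict[str, float], decay: float = 0.9999, use_num_updates: bool = True):
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be in [0, 1]")
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1  # -1 disables warm-up
        self.shadow = dict(params)  # EMA copy starts equal to the model

    def update(self, params: Dict[str, float]) -> None:
        decay = self.decay
        if self.num_updates >= 0:
            self.num_updates += 1
            # warm-up: small decay early, approaching self.decay as updates accumulate
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        for name, value in params.items():
            self.shadow[name] -= (1.0 - decay) * (self.shadow[name] - value)
```

At evaluation time, `store`, `copy_to` and `restore` presumably swap the shadow weights in and out of the live model.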

FILE: sgm/modules/encoders/modules.py
  class AbstractEmbModel (line 27) | class AbstractEmbModel(nn.Module):
    method __init__ (line 28) | def __init__(self):
    method is_trainable (line 35) | def is_trainable(self) -> bool:
    method ucg_rate (line 39) | def ucg_rate(self) -> Union[float, torch.Tensor]:
    method input_key (line 43) | def input_key(self) -> str:
    method is_trainable (line 47) | def is_trainable(self, value: bool):
    method ucg_rate (line 51) | def ucg_rate(self, value: Union[float, torch.Tensor]):
    method input_key (line 55) | def input_key(self, value: str):
    method is_trainable (line 59) | def is_trainable(self):
    method ucg_rate (line 63) | def ucg_rate(self):
    method input_key (line 67) | def input_key(self):
  class GeneralConditioner (line 71) | class GeneralConditioner(nn.Module):
    method __init__ (line 75) | def __init__(self, emb_models: Union[List, ListConfig]):
    method possibly_get_ucg_val (line 111) | def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict...
    method forward (line 120) | def forward(
    method get_unconditional_conditioning (line 167) | def get_unconditional_conditioning(
  class InceptionV3 (line 188) | class InceptionV3(nn.Module):
    method __init__ (line 192) | def __init__(self, normalize_input=False, **kwargs):
    method forward (line 199) | def forward(self, inp):
  class IdentityEncoder (line 208) | class IdentityEncoder(AbstractEmbModel):
    method encode (line 209) | def encode(self, x):
    method forward (line 212) | def forward(self, x):
  class ClassEmbedder (line 216) | class ClassEmbedder(AbstractEmbModel):
    method __init__ (line 217) | def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
    method forward (line 223) | def forward(self, c):
    method get_unconditional_conditioning (line 229) | def get_unconditional_conditioning(self, bs, device="cuda"):
  class ClassEmbedderForMultiCond (line 238) | class ClassEmbedderForMultiCond(ClassEmbedder):
    method forward (line 239) | def forward(self, batch, key=None, disable_dropout=False):
  class FrozenT5Embedder (line 250) | class FrozenT5Embedder(AbstractEmbModel):
    method __init__ (line 253) | def __init__(
    method freeze (line 264) | def freeze(self):
    method forward (line 270) | def forward(self, text):
    method encode (line 286) | def encode(self, text):
  class FrozenByT5Embedder (line 290) | class FrozenByT5Embedder(AbstractEmbModel):
    method __init__ (line 295) | def __init__(
    method freeze (line 306) | def freeze(self):
    method forward (line 312) | def forward(self, text):
    method encode (line 328) | def encode(self, text):
  class FrozenCLIPEmbedder (line 332) | class FrozenCLIPEmbedder(AbstractEmbModel):
    method __init__ (line 337) | def __init__(
    method freeze (line 362) | def freeze(self):
    method forward (line 369) | def forward(self, text):
    method encode (line 393) | def encode(self, text):
  class FrozenOpenCLIPEmbedder2 (line 397) | class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    method __init__ (line 404) | def __init__(
    method freeze (line 439) | def freeze(self):
    method forward (line 445) | def forward(self, text):
    method encode_with_transformer (line 455) | def encode_with_transformer(self, text):
    method pool (line 472) | def pool(self, x, text):
    method text_transformer_forward (line 480) | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
    method encode (line 495) | def encode(self, text):
  class FrozenOpenCLIPEmbedder (line 499) | class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    method __init__ (line 506) | def __init__(
    method freeze (line 535) | def freeze(self):
    method forward (line 540) | def forward(self, text):
    method encode_with_transformer (line 545) | def encode_with_transformer(self, text):
    method text_transformer_forward (line 554) | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
    method encode (line 567) | def encode(self, text):
  class FrozenOpenCLIPImageEmbedder (line 571) | class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    method __init__ (line 576) | def __init__(
    method preprocess (line 621) | def preprocess(self, x):
    method freeze (line 635) | def freeze(self):
    method forward (line 641) | def forward(self, image, no_dropout=False):
    method encode_with_vision_transformer (line 694) | def encode_with_vision_transformer(self, img):
    method encode (line 728) | def encode(self, text):
  class FrozenCLIPT5Encoder (line 732) | class FrozenCLIPT5Encoder(AbstractEmbModel):
    method __init__ (line 733) | def __init__(
    method encode (line 751) | def encode(self, text):
    method forward (line 754) | def forward(self, text):
  class SpatialRescaler (line 760) | class SpatialRescaler(nn.Module):
    method __init__ (line 761) | def __init__(
    method forward (line 800) | def forward(self, x):
    method encode (line 816) | def encode(self, x):
  class LowScaleEncoder (line 820) | class LowScaleEncoder(nn.Module):
    method __init__ (line 821) | def __init__(
    method register_schedule (line 840) | def register_schedule(
    method q_sample (line 888) | def q_sample(self, x_start, t, noise=None):
    method forward (line 896) | def forward(self, x):
    method decode (line 909) | def decode(self, z):
  class ConcatTimestepEmbedderND (line 914) | class ConcatTimestepEmbedderND(AbstractEmbModel):
    method __init__ (line 917) | def __init__(self, outdim):
    method forward (line 922) | def forward(self, x):
  class GaussianEncoder (line 933) | class GaussianEncoder(Encoder, AbstractEmbModel):
    method __init__ (line 934) | def __init__(
    method forward (line 942) | def forward(self, x) -> Tuple[Dict, torch.Tensor]:
  class VideoPredictionEmbedderWithEncoder (line 952) | class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    method __init__ (line 953) | def __init__(
    method forward (line 985) | def forward(
  class FrozenOpenCLIPImagePredictionEmbedder (line 1028) | class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    method __init__ (line 1029) | def __init__(
    method forward (line 1041) | def forward(self, vid):
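Each embedder above exposes a `ucg_rate`, and `GeneralConditioner` applies it when building conditions. The sketch below assumes `ucg_rate` is the per-sample probability of replacing a condition with zeros during training. This kind of conditioning dropout teaches one model to produce both the conditional and the unconditional predictions that classifier-free guidance needs. The function name is hypothetical.

```python
import random
from typing import List, Optional, Sequence


def drop_conditioning(
    embeddings: Sequence[Sequence[float]], ucg_rate: float, rng: Optional[random.Random] = None
) -> List[List[float]]:
    """Zero each sample's conditioning vector with probability ucg_rate."""
    if not 0.0 <= ucg_rate <= 1.0:
        raise ValueError("ucg_rate must be in [0, 1]")
    rng = rng or random.Random()
    out = []
    for emb in embeddings:
        keep = 1.0 if rng.random() >= ucg_rate else 0.0  # Bernoulli(1 - ucg_rate)
        out.append([keep * v for v in emb])
    return out
```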

FILE: sgm/modules/video_attention.py
  class TimeMixSequential (line 7) | class TimeMixSequential(nn.Sequential):
    method forward (line 8) | def forward(self, x, context=None, timesteps=None):
  class VideoTransformerBlock (line 15) | class VideoTransformerBlock(nn.Module):
    method __init__ (line 21) | def __init__(
    method forward (line 105) | def forward(
    method _forward (line 113) | def _forward(self, x, context=None, timesteps=None):
    method get_last_layer (line 156) | def get_last_layer(self):
  class SpatialVideoTransformer (line 160) | class SpatialVideoTransformer(SpatialTransformer):
    method __init__ (line 161) | def __init__(
    method forward (line 244) | def forward(

FILE: sgm/util.py
  function disabled_train (line 14) | def disabled_train(self, mode=True):
  function get_string_from_tuple (line 20) | def get_string_from_tuple(s):
  function is_power_of_two (line 36) | def is_power_of_two(n):
  function autocast (line 52) | def autocast(f, enabled=True):
  function load_partial_from_config (line 64) | def load_partial_from_config(config):
  function log_txt_as_img (line 68) | def log_txt_as_img(wh, xc, size=10):
  function partialclass (line 98) | def partialclass(cls, *args, **kwargs):
  function make_path_absolute (line 105) | def make_path_absolute(path):
  function ismap (line 112) | def ismap(x):
  function isimage (line 118) | def isimage(x):
  function isheatmap (line 124) | def isheatmap(x):
  function isneighbors (line 131) | def isneighbors(x):
  function exists (line 137) | def exists(x):
  function expand_dims_like (line 141) | def expand_dims_like(x, y):
  function default (line 147) | def default(val, d):
  function mean_flat (line 153) | def mean_flat(tensor):
  function count_params (line 161) | def count_params(model, verbose=False):
  function instantiate_from_config (line 168) | def instantiate_from_config(config):
  function get_obj_from_str (line 178) | def get_obj_from_str(string, reload=False, invalidate_cache=True):
  function append_zero (line 188) | def append_zero(x):
  function append_dims (line 192) | def append_dims(x, target_dims):
  function load_model_from_config (line 202) | def load_model_from_config(config, ckpt, verbose=True, freeze=True):
  function get_configs_path (line 233) | def get_configs_path() -> str:
  function get_nested_attribute (line 251) | def get_nested_attribute(obj, attribute_path, depth=None, return_key=Fal...
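`instantiate_from_config` and `get_obj_from_str(string, reload=False, invalidate_cache=True)` follow the LDM/sgm convention. Presumably this is how YAML files such as `configs/inference/config_test.yaml` are consumed: a dotted class path under `target` and keyword arguments under `params`. The sketch below is a simplified version of that pattern. It omits `invalidate_cache` and any special-cased config values the real helpers may handle.

```python
import importlib
from typing import Any, Mapping


def get_obj_from_str(string: str, reload: bool = False) -> Any:
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names."""
    module_name, attr = string.rsplit(".", 1)
    module = importlib.import_module(module_name)
    if reload:
        module = importlib.reload(module)
    return getattr(module, attr)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Build target(**params) from a {'target': ..., 'params': {...}} mapping."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))
```

For example, `{"target": "fractions.Fraction", "params": {"numerator": 1, "denominator": 2}}` builds `Fraction(1, 2)`. A model config nests the same structure for each submodule.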

FILE: utils/save_video.py
  function frames_to_mp4 (line 15) | def frames_to_mp4(frame_dir, output_path, fps):
  function tensor_to_mp4 (line 28) | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
  function tensor2videogrids (line 45) | def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=Tr...
  function flow2rgb (line 63) | def flow2rgb(flow_map, max_value):
  function save_flow_video (line 77) | def save_flow_video(flow_tensor, output_file, fps=10, max_flow=None):
  function save_rgb_video (line 93) | def save_rgb_video(flow_tensor, output_file, fps=10, max_flow=None):
  function log_local (line 109) | def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
  function prepare_to_log (line 171) | def prepare_to_log(batch_logs, max_images=100000, clamp=True):
  function fill_with_black_squares (line 191) | def fill_with_black_squares(video, desired_len: int) -> Tensor:
  function load_num_videos (line 201) | def load_num_videos(data_path, num_videos):
  function npz_to_video_grid (line 214) | def npz_to_video_grid(data_path, out_path, num_frames, fps, num_videos=N...

FILE: utils/tools.py
  function quick_freeze (line 11) | def quick_freeze(model):
  function get_gaussian_kernel (line 16) | def get_gaussian_kernel(kernel_size, sigma, channels):
  function resize_pil_image (line 40) | def resize_pil_image(image, max_resolution=768 * 768, resize_short_edge=...
  function get_unique_embedder_keys_from_conditioner (line 53) | def get_unique_embedder_keys_from_conditioner(conditioner):
  function get_batch (line 57) | def get_batch(keys, value_dict, N, T, device):
  function load_model (line 98) | def load_model(
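`get_gaussian_kernel(kernel_size, sigma, channels)` suggests a Gaussian blur kernel. The `channels` argument hints that the weights are replicated per channel for a depthwise convolution, though that is an assumption. The sketch below computes only the normalized 2-D weights. `gaussian_kernel_2d` is a hypothetical name.

```python
import math
from typing import List


def gaussian_kernel_2d(kernel_size: int, sigma: float) -> List[List[float]]:
    """Normalized 2-D Gaussian weights (summing to 1), centred in a kernel_size x kernel_size grid."""
    if kernel_size < 1 or sigma <= 0:
        raise ValueError("kernel_size must be >= 1 and sigma > 0")
    center = (kernel_size - 1) / 2.0
    weights = [
        [math.exp(-((x - center) ** 2 + (y - center) ** 2) / (2.0 * sigma ** 2)) for x in range(kernel_size)]
        for y in range(kernel_size)
    ]
    total = sum(sum(row) for row in weights)
    return [[w / total for w in row] for row in weights]
```

Normalizing to a unit sum means the blur preserves the mean intensity of the input.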

FILE: utils/visualizer.py
  function read_video_from_path (line 18) | def read_video_from_path(path):
  function draw_circle (line 30) | def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
  function draw_line (line 45) | def draw_line(rgb, coord_y, coord_x, color, linewidth):
  function add_weighted (line 55) | def add_weighted(rgb, alpha, original, beta, gamma):
  class Visualizer (line 59) | class Visualizer:
    method __init__ (line 60) | def __init__(
    method visualize (line 84) | def visualize(
    method save_video (line 130) | def save_video(self, video, filename, writer=None, step=0):
    method draw_tracks_on_video (line 159) | def draw_tracks_on_video(
    method _draw_pred_tracks (line 281) | def _draw_pred_tracks(
    method _draw_gt_tracks (line 312) | def _draw_gt_tracks(
Condensed preview: 80 files, each showing path, character count, and a content snippet.
[
  {
    "path": "LICENSE",
    "chars": 5928,
    "preview": "STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT\t\nDated: November 21, 2023\n\n“AUP” means the Stability A"
  },
  {
    "path": "README.md",
    "chars": 6095,
    "preview": "# ReVideo: Remake a Video with Motion and Content Control\n[Chong Mou](https://scholar.google.com/citations?user=SYQoDk0A"
  },
  {
    "path": "configs/examples/constant_motion/head6.sh",
    "chars": 586,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/constant_motion/head7.sh",
    "chars": 584,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/constant_motion/kong.sh",
    "chars": 585,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/constant_motion/monkey.sh",
    "chars": 583,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/constant_motion/woman2.sh",
    "chars": 584,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/constant_motion/woman5.sh",
    "chars": 589,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/multi_region/lawn2.sh",
    "chars": 957,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/multi_region/woman.sh",
    "chars": 960,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/desert.sh",
    "chars": 624,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/dog.sh",
    "chars": 642,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/football.sh",
    "chars": 624,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/forest.sh",
    "chars": 639,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/head5.sh",
    "chars": 630,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/lawn.sh",
    "chars": 645,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/lizard.sh",
    "chars": 618,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/road.sh",
    "chars": 615,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/sea.sh",
    "chars": 629,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/sea2.sh",
    "chars": 615,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/sky.sh",
    "chars": 611,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/examples/single_region/woman4.sh",
    "chars": 615,
    "preview": "name=\"svd-example2[fps6_mb127-temp]\"\nconfig=\"configs/inference/config_test.yaml\"\nckpt=\"ckpt/model.ckpt\"\nimage_input=\"tes"
  },
  {
    "path": "configs/inference/config_test.yaml",
    "chars": 6147,
    "preview": "num_frames: &num_frames 14\nmodel:\n  base_learning_rate: 3.0e-5\n  target: ctrl_model.diffusion_ctrl.DiffusionEngineCtrl\n "
  },
  {
    "path": "ctrl_model/diffusion_ctrl.py",
    "chars": 4385,
    "preview": "from functools import partial\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nfrom einops import rearrange,"
  },
  {
    "path": "ctrl_model/svd_ctrl.py",
    "chars": 24228,
    "preview": "from functools import partial\n\nfrom typing import List, Optional, Union, Tuple\n\nfrom einops import rearrange\n\nimport tor"
  },
  {
    "path": "main/inference/sample_constant_motion.py",
    "chars": 13355,
    "preview": "import datetime, time\nimport os, sys, argparse\nimport math\nfrom glob import glob\nfrom pathlib import Path\nfrom typing im"
  },
  {
    "path": "main/inference/sample_multi_region.py",
    "chars": 14221,
    "preview": "import datetime, time\nimport os, sys, argparse\nimport math\nfrom glob import glob\nfrom pathlib import Path\nfrom typing im"
  },
  {
    "path": "main/inference/sample_single_region.py",
    "chars": 14437,
    "preview": "import datetime, time\nimport os, sys, argparse\nimport math\nfrom glob import glob\nfrom pathlib import Path\nfrom typing im"
  },
  {
    "path": "requirements.txt",
    "chars": 744,
    "preview": "black==23.7.0\nchardet==5.1.0\nclip\neinops>=0.6.1\nfairscale==0.4.13\nfire==0.5.0\nfsspec>=2023.6.0\ninvisible-watermark>=0.2."
  },
  {
    "path": "sgm/__init__.py",
    "chars": 139,
    "preview": "from .models import AutoencodingEngine, DiffusionEngine\nfrom .util import get_configs_path, instantiate_from_config\n\n__v"
  },
  {
    "path": "sgm/inference/api.py",
    "chars": 11757,
    "preview": "import pathlib\nfrom dataclasses import asdict, dataclass\nfrom enum import Enum\nfrom typing import Optional\n\nfrom omegaco"
  },
  {
    "path": "sgm/inference/helpers.py",
    "chars": 10676,
    "preview": "import math\nimport os\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom einops import rearr"
  },
  {
    "path": "sgm/lr_scheduler.py",
    "chars": 4286,
    "preview": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n\n    def __i"
  },
  {
    "path": "sgm/models/__init__.py",
    "chars": 83,
    "preview": "from .autoencoder import AutoencodingEngine\nfrom .diffusion import DiffusionEngine\n"
  },
  {
    "path": "sgm/models/autoencoder.py",
    "chars": 22678,
    "preview": "import logging\nimport math\nimport re\nfrom abc import abstractmethod\nfrom contextlib import contextmanager\nfrom typing im"
  },
  {
    "path": "sgm/models/diffusion.py",
    "chars": 14391,
    "preview": "import math\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nfrom einops"
  },
  {
    "path": "sgm/modules/__init__.py",
    "chars": 159,
    "preview": "from .encoders.modules import GeneralConditioner\n\nUNCONDITIONAL_CONFIG = {\n    \"target\": \"sgm.modules.GeneralConditioner"
  },
  {
    "path": "sgm/modules/attention.py",
    "chars": 25335,
    "preview": "import logging\nimport math\nfrom inspect import isfunction\nfrom typing import Any, Optional\n\nimport torch\nimport torch.nn"
  },
  {
    "path": "sgm/modules/autoencoding/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/autoencoding/losses/__init__.py",
    "chars": 164,
    "preview": "__all__ = [\n    \"GeneralLPIPSWithDiscriminator\",\n    \"LatentLPIPS\",\n]\n\nfrom .discriminator_loss import GeneralLPIPSWithD"
  },
  {
    "path": "sgm/modules/autoencoding/losses/discriminator_loss.py",
    "chars": 12077,
    "preview": "from typing import Dict, Iterator, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\ni"
  },
  {
    "path": "sgm/modules/autoencoding/losses/lpips.py",
    "chars": 2915,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom ....util import default, instantiate_from_config\nfrom ..lpips.loss.lpips import"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/.gitignore",
    "chars": 7,
    "preview": "vgg.pth"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/LICENSE",
    "chars": 1355,
    "preview": "Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang\nAll rights reserved.\n\nRedi"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/lpips.py",
    "chars": 5221,
    "preview": "\"\"\"Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models\"\"\"\n\nfrom collections import "
  },
  {
    "path": "sgm/modules/autoencoding/lpips/model/LICENSE",
    "chars": 3564,
    "preview": "Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\nAll rights reserved.\n\nRedistribution and use in source and binary forms"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/model/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/autoencoding/lpips/model/model.py",
    "chars": 2850,
    "preview": "import functools\n\nimport torch.nn as nn\n\nfrom ..util import ActNorm\n\n\ndef weights_init(m):\n    classname = m.__class__._"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/util.py",
    "chars": 3954,
    "preview": "import hashlib\nimport os\n\nimport requests\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\n\nURL_MAP = {\"vgg_lpip"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/vqperceptual.py",
    "chars": 480,
    "preview": "import torch\nimport torch.nn.functional as F\n\n\ndef hinge_d_loss(logits_real, logits_fake):\n    loss_real = torch.mean(F."
  },
  {
    "path": "sgm/modules/autoencoding/regularizers/__init__.py",
    "chars": 877,
    "preview": "from abc import abstractmethod\nfrom typing import Any, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functio"
  },
  {
    "path": "sgm/modules/autoencoding/regularizers/base.py",
    "chars": 1254,
    "preview": "from abc import abstractmethod\nfrom typing import Any, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch im"
  },
  {
    "path": "sgm/modules/autoencoding/regularizers/quantize.py",
    "chars": 17424,
    "preview": "import logging\nfrom abc import abstractmethod\nfrom typing import Dict, Iterator, Literal, Optional, Tuple, Union\n\nimport"
  },
  {
    "path": "sgm/modules/autoencoding/temporal_ae.py",
    "chars": 11570,
    "preview": "from typing import Callable, Iterable, Union\n\nimport torch\nfrom einops import rearrange, repeat\n\nfrom sgm.modules.diffus"
  },
  {
    "path": "sgm/modules/diffusionmodules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/diffusionmodules/denoiser.py",
    "chars": 2615,
    "preview": "from typing import Dict, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom ...util import append_dims, instantiate_from_co"
  },
  {
    "path": "sgm/modules/diffusionmodules/denoiser_scaling.py",
    "chars": 2418,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Tuple\n\nimport torch\n\n\nclass DenoiserScaling(ABC):\n    @abstractme"
  },
  {
    "path": "sgm/modules/diffusionmodules/denoiser_weighting.py",
    "chars": 516,
    "preview": "import torch\n\n\nclass UnitWeighting:\n    def __call__(self, sigma):\n        return torch.ones_like(sigma, device=sigma.de"
  },
  {
    "path": "sgm/modules/diffusionmodules/discretizer.py",
    "chars": 2314,
    "preview": "from abc import abstractmethod\nfrom functools import partial\n\nimport numpy as np\nimport torch\n\nfrom ...modules.diffusion"
  },
  {
    "path": "sgm/modules/diffusionmodules/guiders.py",
    "chars": 4698,
    "preview": "import logging\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nf"
  },
  {
    "path": "sgm/modules/diffusionmodules/model.py",
    "chars": 24004,
    "preview": "# pytorch_diffusion + derived encoder decoder\nimport logging\nimport math\nfrom typing import Any, Callable, Optional\n\nimp"
  },
  {
    "path": "sgm/modules/diffusionmodules/openaimodel.py",
    "chars": 31713,
    "preview": "import logging\nimport math\nfrom abc import abstractmethod\nfrom typing import Iterable, List, Optional, Tuple, Union\n\nimp"
  },
  {
    "path": "sgm/modules/diffusionmodules/sampling.py",
    "chars": 12025,
    "preview": "\"\"\"\n    Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py\n\"\"\"\n\n\nfrom ty"
  },
  {
    "path": "sgm/modules/diffusionmodules/sampling_utils.py",
    "chars": 1029,
    "preview": "import torch\nfrom scipy import integrate\n\nfrom ...util import append_dims\n\n\ndef linear_multistep_coeff(order, t, i, j, e"
  },
  {
    "path": "sgm/modules/diffusionmodules/sigma_sampling.py",
    "chars": 906,
    "preview": "import torch\n\nfrom ...util import default, instantiate_from_config\n\n\nclass EDMSampling:\n    def __init__(self, p_mean=-1"
  },
  {
    "path": "sgm/modules/diffusionmodules/util.py",
    "chars": 12426,
    "preview": "\"\"\"\npartially adopted from\nhttps://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion."
  },
  {
    "path": "sgm/modules/diffusionmodules/video_model.py",
    "chars": 18083,
    "preview": "from functools import partial\nfrom typing import List, Optional, Union\n\nfrom einops import rearrange\n\nimport torch\nfrom "
  },
  {
    "path": "sgm/modules/diffusionmodules/wrappers.py",
    "chars": 1032,
    "preview": "import torch\nimport torch.nn as nn\nfrom packaging import version\n\nOPENAIUNETWRAPPER = \"sgm.modules.diffusionmodules.wrap"
  },
  {
    "path": "sgm/modules/distributions/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/distributions/distributions.py",
    "chars": 3095,
    "preview": "import numpy as np\nimport torch\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n"
  },
  {
    "path": "sgm/modules/ema.py",
    "chars": 3207,
    "preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates="
  },
  {
    "path": "sgm/modules/encoders/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sgm/modules/encoders/modules.py",
    "chars": 34886,
    "preview": "import math\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Dict, List, Optional, Tu"
  },
  {
    "path": "sgm/modules/video_attention.py",
    "chars": 10106,
    "preview": "import torch\n\nfrom ..modules.attention import *\nfrom ..modules.diffusionmodules.util import AlphaBlender, linear, timest"
  },
  {
    "path": "sgm/util.py",
    "chars": 8466,
    "preview": "import functools\nimport importlib\nimport os\nfrom functools import partial\nfrom inspect import isfunction\n\nimport fsspec\n"
  },
  {
    "path": "utils/save_video.py",
    "chars": 10605,
    "preview": "import os\nimport numpy as np\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom einops import rearrange\nimport cv2\n\nimport"
  },
  {
    "path": "utils/tools.py",
    "chars": 4067,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom sgm.util import default, inst"
  },
  {
    "path": "utils/visualizer.py",
    "chars": 12487,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  }
]
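The `config_test.yaml` preview uses the `target:` convention (`target: ctrl_model.diffusion_ctrl.DiffusionEngineCtrl`), and `sgm/__init__.py` re-exports `instantiate_from_config` from `sgm.util`. In the `sgm` code lineage, this helper imports the dotted `target` path and calls it with the mapping stored under `params`. The following is a minimal standalone sketch of that pattern. It assumes a plain-dict config, the `params` key, and a helper named `get_obj_from_str`. The repository's own version is not reproduced in this extraction and may handle additional special cases.

```python
import importlib
from typing import Any, Mapping


def get_obj_from_str(path: str) -> Any:
    """Resolve a dotted path such as 'package.module.ClassName' to the object."""
    module_name, attr_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Sketch of the target/params convention: config['target'](**config['params'])."""
    if "target" not in config:
        raise KeyError("Expected key 'target' to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))


# Demonstrated with a standard-library class instead of DiffusionEngineCtrl,
# which needs the repository and its checkpoint to import.
cfg = {"target": "fractions.Fraction", "params": {"numerator": 3, "denominator": 6}}
print(instantiate_from_config(cfg))  # 1/2
```

Because the class is resolved from a string at runtime, a YAML file can swap model components (for example, replacing the stock `DiffusionEngine` with `DiffusionEngineCtrl`) without any code changes.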

About this extraction

This page contains the full source code of the MC-E/ReVideo GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 80 files (445.7 KB), approximately 113.2k tokens, and a symbol index with 668 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!