Showing preview only (325K chars total). Download the full file or copy to clipboard to get everything.
Repository: zhuangshaobin/Vlogger
Branch: main
Commit: 0e1766b92a99
Files: 39
Total size: 20.8 MB
Directory structure:
gitextract_ybj531gl/
├── LICENSE
├── MSYH.TTC
├── README.md
├── configs/
│ ├── vlog_read_script_sample.yaml
│ ├── vlog_write_script.yaml
│ ├── with_mask_ref_sample.yaml
│ └── with_mask_sample.yaml
├── datasets/
│ └── video_transforms.py
├── diffusion/
│ ├── __init__.py
│ ├── diffusion_utils.py
│ ├── gaussian_diffusion.py
│ ├── respace.py
│ └── timestep_sampler.py
├── models/
│ ├── __init__.py
│ ├── attention.py
│ ├── clip.py
│ ├── resnet.py
│ ├── unet.py
│ ├── unet_blocks.py
│ └── utils.py
├── requirements.txt
├── results/
│ └── vlog/
│ ├── teddy_travel/
│ │ ├── script/
│ │ │ ├── audio_prompts.txt
│ │ │ ├── protagonist_place_reference.txt
│ │ │ ├── protagonists_places.txt
│ │ │ ├── time_scripts.txt
│ │ │ ├── video_prompts.txt
│ │ │ └── zh_video_prompts.txt
│ │ └── story.txt
│ └── teddy_travel_/
│ └── story.txt
├── sample_scripts/
│ ├── vlog_read_script_sample.py
│ ├── vlog_write_script.py
│ ├── with_mask_ref_sample.py
│ └── with_mask_sample.py
├── utils.py
└── vlogger/
├── STEB/
│ └── model_transform.py
├── planning_utils/
│ └── gpt4_utils.py
├── videoaudio.py
├── videocaption.py
└── videofusion.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MSYH.TTC
================================================
[File too large to display: 20.5 MB]
================================================
FILE: README.md
================================================
<div align="center">
<h1 align="center">Vlogger: Make Your Dream A Vlog</h1>
</a>
[Shaobin Zhuang](https://github.com/zhuangshaobin), [Kunchang Li](https://scholar.google.com/citations?user=D4tLSbsAAAAJ), [Xinyuan Chen†](https://scholar.google.com/citations?user=3fWSC8YAAAAJ), [Yaohui Wang†](https://scholar.google.com/citations?user=R7LyAb4AAAAJ), [Ziwei Liu](https://scholar.google.com/citations?user=lc45xlcAAAAJ), [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ&hl), [Yali Wang†](https://scholar.google.com/citations?user=hD948dkAAAAJ)
[](https://arxiv.org/abs/2401.09414)
[](https://zhuangshaobin.github.io/Vlogger.github.io/)
[](https://huggingface.co/GrayShine/Vlogger)
[](https://huggingface.co/spaces/GrayShine/Vlogger-ShowMaker)
[](https://youtu.be/ZRD1-jHbEGk)
</div>
</div>
In this work, we present **Vlogger**, a generic AI system for generating a **minute**-level video blog (i.e., vlog) of user descriptions. Different from short videos with a few seconds, vlog often contains a complex storyline with diversified scenes, which is challenging for most existing video generation approaches. To break through this bottleneck, our Vlogger smartly leverages Large Language Model (LLM) as Director and decomposes a long video generation task of vlog into four key stages, where we invoke various foundation models to play the critical roles of vlog professionals, including (1) Script, (2) Actor, (3) ShowMaker, and (4) Voicer. With such a design of mimicking human beings, our Vlogger can generate vlogs through explainable cooperation of top-down planning and bottom-up shooting. Moreover, we introduce a novel video diffusion model, **ShowMaker**, which serves as a videographer in our Vlogger for generating the video snippet of each shooting scene. By incorporating Script and Actor attentively as textual and visual prompts, it can effectively enhance spatial-temporal coherence in the snippet. Besides, we design a concise mixed training paradigm for ShowMaker, boosting its capacity for both T2V generation and prediction. Finally, the extensive experiments show that our method achieves state-of-the-art performance on zero-shot T2V generation and prediction tasks. More importantly, Vlogger can generate over 5-minute vlogs from open-world descriptions, without loss of video coherence on script and actor.
<div align="center">
<video src="https://github.com/zhuangshaobin/Vlogger/assets/94739615/1e8dd246-d3b9-49e9-8eee-d40b6d8523b9" controls="controls" width="500" height="300"></video>
<b>A compressed version of generated <a href="https://youtu.be/ZRD1-jHbEGk">Teddy Travel</a>.</b>
</div>
## Usage
<details>
<summary><h3>Setup</h3></summary>
<h4>Prepare Environment</h4>
```bash
conda create -n vlogger python==3.10.11
conda activate vlogger
pip install -r requirements.txt
```
<h4>Download our model and T2I base model</h4>
Our model is based on Stable diffusion v1.4, you may download [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [OpenCLIP-ViT-H-14](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K) to the director of ``` pretrained ```
.
Download our model(ShowMaker) checkpoint (from [google drive](https://drive.google.com/file/d/1pAH73kz2QRfD2Dxk4lL3SrHvLAlWcPI3/view?usp=drive_link) or [hugging face](https://huggingface.co/GrayShine/Vlogger/tree/main)) and save to the directory of ```pretrained```
Now under `./pretrained`, you should be able to see the following:
```
├── pretrained
│ ├── ShowMaker.pt
│ ├── stable-diffusion-v1-4
│ ├── OpenCLIP-ViT-H-14
│ │ ├── ...
└── └── ├── ...
├── ...
```
</details>
<details>
<summary><h3>Inference for LLM planning and make reference image</h3></summary>
Run the following command to get script, actors and protagonist:
```python
python sample_scripts/vlog_write_script.py
```
- The generated scripts will be saved in ```results/vlog/$your_story_dir/script```.
- The generated reference images will be saved in ```results/vlog/$your_story_dir/img```.
- :warning: Enter your openai key in the 7th line of the file ```vlogger/planning_utils/gpt4_utils.py```
</details>
<details>
<summary><h3>Inference for vlog generation</h3></summary>
Run the following command to get the vlog:
```python
python sample_scripts/vlog_read_script_sample.py
```
- The generated scripts will be saved in ```results/vlog/$your_story_dir/video```.
</details>
<details>
<summary><h3>Inference for (T+I)2V </h3></summary>
Run the following command to get the (T+I)2V results:
```python
python sample_scripts/with_mask_sample.py
```
- The generated video will be saved in ```results/mask_no_ref```.
</details>
<details>
<summary><h3>Inference for (T+I+Ref)2V</h3></summary>
Run the following command to get the (T+I+Ref)2V results:
```python
python sample_scripts/with_mask_ref_sample.py
```
- The generated video will be saved in ```results/mask_ref```.
</details>
<details>
<summary><h3>More Details</h3></summary>
You may modify ```configs/with_mask_sample.yaml``` to change the (T+I)2V conditions and modify ```configs/with_mask_ref_sample.yaml``` to change the (T+I+Ref)2V conditions.
For example:
- ```ckpt``` is used to specify a model checkpoint.
- ```text_prompt``` is used to describe the content of the video.
- ```input_path``` is used to specify the path to the image.
- ```ref_path``` is used to specify the path to the reference image.
- ```save_path``` is used to specify the path to the generated video.
</details>
## Results
### (T+Ref)2V Results
<table class="center">
<tr>
<td style="text-align:center;width: 50%" colspan="1"><b>Reference Image</b></td>
<td style="text-align:center;width: 50%" colspan="1"><b>Output Video</b></td>
</tr>
<tr>
<td><img src="examples/TR2V/image/Egyptian_Pyramids.png" width="250">
<br>
<!-- <div class="text" style=" text-align:center;">
Scene Reference
</div> -->
<p align="center">Scene Reference</p>
</td>
<td>
<img src="examples/TR2V/video/Fireworks_explode_over_the_pyramids.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
Fireworks explode over the pyramids.
</div> -->
<p align="center">Fireworks explode over the pyramids.</p>
</td>
</tr>
<tr>
<td><img src="examples/TR2V/image/Great_Wall.png" width="250">
<br>
<!-- <div class="text" style=" text-align:center;">
Scene Reference
</div> -->
<p align="center">Scene Reference</p>
</td>
<td>
<img src="examples/TR2V/video/The_Great_Wall_burning_with_raging_fire.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
The Great Wall burning with raging fire.
</div> -->
<p align="center">The Great Wall burning with raging fire.</p>
</td>
</tr>
<tr>
<td><img src="examples/TR2V/image/a_green_cat.png" width="250">
<br>
<!-- <div class="text" style=" text-align:center;">
Object Reference
</div> -->
<p align="center">Object Reference</p>
</td>
<td>
<img src="examples/TR2V/video/A_cat_is_running_on_the_beach.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
A cat is running on the beach.
</div> -->
<p align="center">A cat is running on the beach.</p>
</td>
</tr>
</table>
### (T+I)2V Results
<table class="center">
<tr>
<td style="text-align:center;width: 50%" colspan="1"><b>Input Image</b></td>
<td style="text-align:center;width: 50%" colspan="1"><b>Output Video</b></td>
</tr>
<tr>
<td><img src="input/i2v/Underwater_environment_cosmetic_bottles.png" width="400"></td>
<td>
<img src="examples/TI2V/Underwater_environment_cosmetic_bottles.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
Underwater environment cosmetic bottles.
</div> -->
<p align="center">Underwater environment cosmetic bottles.</p>
</td>
</tr>
<tr>
<td><img src="input/i2v/A_big_drop_of_water_falls_on_a_rose_petal.png" width="400"></td>
<td>
<img src="examples/TI2V/A_big_drop_of_water_falls_on_a_rose_petal.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
A big drop of water falls on a rose petal.
</div> -->
<p align="center">A big drop of water falls on a rose petal.</p>
</td>
</tr>
<tr>
<td><img src="input/i2v/A_fish_swims_past_an_oriental_woman.png" width="400"></td>
<td>
<img src="examples/TI2V/A_fish_swims_past_an_oriental_woman.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
A fish swims past an oriental woman.
</div> -->
<p align="center">A fish swims past an oriental woman.</p>
</td>
</tr>
<tr>
<td><img src="input/i2v/Cinematic_photograph_View_of_piloting_aaero.png" width="400"></td>
<td>
<img src="examples/TI2V/Cinematic_photograph_View_of_piloting_aaero.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
Cinematic photograph. View of piloting aaero.
</div> -->
<p align="center">Cinematic photograph. View of piloting aaero.</p>
</td>
</tr>
<tr>
<td><img src="input/i2v/Planet_hits_earth.png" width="400"></td>
<td>
<img src="examples/TI2V/Planet_hits_earth.gif" width="400">
<br>
<!-- <div class="text" style=" text-align:center;">
Planet hits earth.
</div> -->
<p align="center">Planet hits earth.</p>
</td>
</tr>
</table>
### T2V Results
<table>
<tr>
<td style="text-align:center;width: 66%" colspan="2"><b>Output Video</b></td>
</tr>
<tr>
<td>
<img src="examples/T2V/A_deer_looks_at_the_sunset_behind_him.gif"/>
<br>
<!-- <div class="text" style=" text-align:center;">
A deer looks at the sunset behind him.
</div> -->
<p align="center">A deer looks at the sunset behind him.</p>
</td>
<td>
<img src="examples/T2V/A_duck_is_teaching_math_to_another_duck.gif"/>
<br>
<!-- <div class="text" style=" text-align:center;">
A duck is teaching math to another duck.
</div> -->
<p align="center">A duck is teaching math to another duck.</p>
</td>
</tr>
<tr>
<td>
<img src="examples/T2V/Bezos_explores_tropical_rainforest.gif"/>
<br>
<!-- <div class="text" style=" text-align:center;">
Bezos explores tropical rainforest.
</div> -->
<p align="center">Bezos explores tropical rainforest.</p>
</td>
<td>
<img src="examples/T2V/Light_blue_water_lapping_on_the_beach.gif"/>
<br>
<!-- <div class="text" style=" text-align:center;">
Light blue water lapping on the beach.
</div> -->
<p align="center">Light blue water lapping on the beach.</p>
</td>
</tr>
</table>
## BibTeX
```bibtex
@article{zhuang2024vlogger,
title={Vlogger: Make Your Dream A Vlog},
author={Zhuang, Shaobin and Li, Kunchang and Chen, Xinyuan and Wang, Yaohui and Liu, Ziwei and Qiao, Yu and Wang, Yali},
journal={arXiv preprint arXiv:2401.09414},
year={2024}
}
```
```bibtex
@article{chen2023seine,
title={SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction},
author={Chen, Xinyuan and Wang, Yaohui and Zhang, Lingjun and Zhuang, Shaobin and Ma, Xin and Yu, Jiashuo and Wang, Yali and Lin, Dahua and Qiao, Yu and Liu, Ziwei},
journal={arXiv preprint arXiv:2310.20700},
year={2023}
}
```
```bibtex
@article{wang2023lavie,
title={LAVIE: High-Quality Video Generation with Cascaded Latent Diffusion Models},
author={Wang, Yaohui and Chen, Xinyuan and Ma, Xin and Zhou, Shangchen and Huang, Ziqi and Wang, Yi and Yang, Ceyuan and He, Yinan and Yu, Jiashuo and Yang, Peiqing and others},
journal={arXiv preprint arXiv:2309.15103},
year={2023}
}
```
## Disclaimer
We disclaim responsibility for user-generated content. The model was not trained to realistically represent people or events, so using it to generate such content is beyond the model's capabilities. It is prohibited for pornographic, violent and bloody content generation, and to generate content that is demeaning or harmful to people or their environment, culture, religion, etc. Users are solely liable for their actions. The project contributors are not legally affiliated with, nor accountable for users' behaviors. Use the generative model responsibly, adhering to ethical and legal standards.
## Contact Us
**Shaobin Zhuang**: [zhuangshaobin@pjlab.org.cn](mailto:zhuangshaobin@pjlab.org.cn), **Kunchang Li**: [likunchang@pjlab.org.cn](mailto:likunchang@pjlab.org.cn)
**Xinyuan Chen**: [chenxinyuan@pjlab.org.cn](mailto:chenxinyuan@pjlab.org.cn), **Yaohui Wang**: [wangyaohui@pjlab.org.cn](mailto:wangyaohui@pjlab.org.cn)
## Acknowledgements
The code is built upon [SEINE](https://github.com/Vchitect/SEINE), [LaVie](https://github.com/Vchitect/LaVie), [diffusers](https://github.com/huggingface/diffusers) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion), we thank all the contributors for open-sourcing.
## License
The code is licensed under Apache-2.0, model weights are fully open for academic research and also allow **free** commercial usage. To apply for a commercial license, please contact zhuangshaobin@pjlab.org.cn.
================================================
FILE: configs/vlog_read_script_sample.yaml
================================================
# path:
ckpt: "pretrained/ShowMaker.pt"
pretrained_model_path: "pretrained/stable-diffusion-v1-4/"
image_encoder_path: "pretrained/OpenCLIP-ViT-H-14"
save_path: "results/vlog/teddy_travel/video"
# script path
reference_image_path: ["results/vlog/teddy_travel/ref_img/teddy.jpg"]
script_file_path: "results/vlog/teddy_travel/script/video_prompts.txt"
zh_script_file_path: "results/vlog/teddy_travel/script/zh_video_prompts.txt"
protagonist_file_path: "results/vlog/teddy_travel/script/protagonists_places.txt"
reference_file_path: "results/vlog/teddy_travel/script/protagonist_place_reference.txt"
time_file_path: "results/vlog/teddy_travel/script/time_scripts.txt"
video_transition: False
# model config:
model: UNet
num_frames: 16
image_size: [320, 512]
negative_prompt: "white background"
# sample config:
ref_cfg_scale: 0.3
seed: 3407
guidance_scale: 7.5
cfg_scale: 8.0
sample_method: 'ddim'
num_sampling_steps: 100
researve_frame: 3
mask_type: "first3"
use_mask: True
use_fp16: True
enable_xformers_memory_efficient_attention: True
do_classifier_free_guidance: True
fps: 8
sample_num:
# model speedup
use_compile: False
================================================
FILE: configs/vlog_write_script.yaml
================================================
# script path
story_path: "./results/vlog/teddy_travel_/story.txt"
only_one_protagonist: False
================================================
FILE: configs/with_mask_ref_sample.yaml
================================================
# path config:
ckpt: "pretrained/ShowMaker.pt"
pretrained_model_path: "pretrained/stable-diffusion-v1-4/"
image_encoder_path: "pretrained/OpenCLIP-ViT-H-14"
input_path: 'input/i2v/Planet_hits_earth.png'
ref_path: 'input/i2v/Planet_hits_earth.png'
save_path: "results/mask_ref/"
# model config:
model: UNet
num_frames: 16
# image_size: [320, 512]
image_size: [240, 560]
# model speedup
use_fp16: True
enable_xformers_memory_efficient_attention: True
# sample config:
seed: 3407
cfg_scale: 8.0
ref_cfg_scale: 0.5
sample_method: 'ddim'
num_sampling_steps: 100
text_prompt: [
# "Cinematic photograph. View of piloting aaero.",
# "A fish swims past an oriental woman.",
# "A big drop of water falls on a rose petal.",
# "Underwater environment cosmetic bottles.".
"Planet hits earth.",
]
additional_prompt: ""
negative_prompt: ""
do_classifier_free_guidance: True
mask_type: "first1"
use_mask: True
================================================
FILE: configs/with_mask_sample.yaml
================================================
# path config:
ckpt: "pretrained/ShowMaker.pt"
pretrained_model_path: "pretrained/OpenCLIP-ViT-H-14"
input_path: 'input/i2v/Planet_hits_earth.png'
save_path: "results/mask_no_ref/"
# model config:
model: UNet
num_frames: 16
# image_size: [320, 512]
image_size: [240, 560]
# model speedup
use_fp16: True
enable_xformers_memory_efficient_attention: True
# sample config:
seed: 3407
cfg_scale: 8.0
sample_method: 'ddim'
num_sampling_steps: 100
text_prompt: [
# "Cinematic photograph. View of piloting aaero.",
# "A fish swims past an oriental woman.",
# "A big drop of water falls on a rose petal.",
# "Underwater environment cosmetic bottles.".
"Planet hits earth.",
]
additional_prompt: ""
negative_prompt: ""
do_classifier_free_guidance: True
mask_type: "first1"
use_mask: True
================================================
FILE: datasets/video_transforms.py
================================================
import torch
import random
import numbers
from torchvision.transforms import RandomCrop, RandomResizedCrop
from PIL import Image
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
def resize_scale_with_height(clip, target_size, interpolation_mode):
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size / H
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resize_scale_with_weight(clip, target_size, interpolation_mode):
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size / W
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
# print(clip.shape)
th, tw = crop_size
if h < th or w < tw:
# print(h, w)
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def random_shift_crop(clip):
'''
Slide along the long edge, with the short edge as crop size
'''
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
long_edge = w
short_edge = h
else:
long_edge = h
short_edge =w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# print(clip.shape)
clip_center_crop = center_crop_using_short_edge(clip)
# print(clip_center_crop.shape) 320 512
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class ResizeVideo():
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
return clip_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
================================================
FILE: diffusion/__init__.py
================================================
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
def create_diffusion(
timestep_respacing,
noise_schedule="linear",
use_kl=False,
sigma_small=False,
predict_xstart=False,
# learn_sigma=True,
learn_sigma=False, # for unet
rescale_learned_sigmas=False,
diffusion_steps=1000
):
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if timestep_respacing is None or timestep_respacing == "":
timestep_respacing = [diffusion_steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type
# rescale_timesteps=rescale_timesteps,
)
================================================
FILE: diffusion/diffusion_utils.py
================================================
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import torch as th
import numpy as np
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, th.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a continuous Gaussian distribution.
:param x: the targets
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
centered_x = x - means
inv_stdv = th.exp(-log_scales)
normalized_x = centered_x * inv_stdv
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
return log_probs
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
================================================
FILE: diffusion/gaussian_diffusion.py
================================================
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import math
import numpy as np
import torch as th
import enum
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = (
enum.auto()
) # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start ** 0.5,
beta_end ** 0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
# diffuser stable diffusion
# beta_start=scale * 0.00085,
# beta_end=scale * 0.012,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Original ported from this codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
"""
def __init__(
self,
*,
betas,
model_mean_type,
model_var_type,
loss_type
):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
) if len(self.posterior_variance) > 1 else np.array([])
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
mask=None, x_start=None, use_concat=False):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, F, C = x.shape[:3]
assert t.shape == (B,)
if use_concat:
model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
else:
model_output = model(x, t, **model_kwargs)
try:
model_output = model_output.sample # for tav unet
except:
pass
# model_output = model(x, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = None
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, F, C * 2, *x.shape[3:])
model_output, model_var_values = th.split(model_output, C, dim=2)
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
"extra": extra,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, t, **model_kwargs)
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
return out
def p_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
mask=None,
x_start=None,
use_concat=False
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
mask=mask,
x_start=x_start,
use_concat=use_concat
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
mask=None,
x_start=None,
use_concat=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
mask=mask,
x_start=x_start,
use_concat=use_concat
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
mask=None,
x_start=None,
use_concat=False
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
mask=mask,
x_start=x_start,
use_concat=use_concat
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
mask=None,
x_start=None,
use_concat=False
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
mask=mask,
x_start=x_start,
use_concat=use_concat
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
mask=None,
x_start=None,
use_concat=False
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
mask=mask,
x_start=x_start,
use_concat=use_concat
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
mask=None,
x_start=None,
use_concat=False
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
mask=mask,
x_start=x_start,
use_concat=use_concat
)
yield out
img = out["sample"]
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
if use_mask:
x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output = model(x_t, t, **model_kwargs)
try:
# model_output = model(x_t, t, **model_kwargs).sample
model_output = model_output.sample # for tav unet
except:
pass
# model_output = model(x_t, t, **model_kwargs)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, F, C = x_t.shape[:3]
assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
model_output, model_var_values = th.split(model_output, C, dim=2)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
# assert model_output.shape == target.shape == x_start.shape
if use_mask:
terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
else:
terms["mse"] = mean_flat((target - model_output) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
else:
raise NotImplementedError(self.loss_type)
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = th.tensor([t] * batch_size, device=device)
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with th.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
mse.append(mean_flat((eps - noise) ** 2))
vb = th.stack(vb, dim=1)
xstart_mse = th.stack(xstart_mse, dim=1)
mse = th.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
"total_bpd": total_bpd,
"prior_bpd": prior_bpd,
"vb": vb,
"xstart_mse": xstart_mse,
"mse": mse,
}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
================================================
FILE: diffusion/respace.py
================================================
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import torch
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
# @torch.compile
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(
model, self.timestep_map, self.original_num_steps
)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, original_num_steps):
self.model = model
self.timestep_map = timestep_map
# self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# if self.rescale_timesteps:
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
================================================
FILE: diffusion/timestep_sampler.py
================================================
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
elif name == "loss-second-moment":
return LossSecondMomentResampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
self.diffusion = diffusion
self._weights = np.ones([diffusion.num_timesteps])
def weights(self):
return self._weights
class LossAwareSampler(ScheduleSampler):
def update_with_local_losses(self, local_ts, local_losses):
"""
Update the reweighting using losses from a model.
Call this method from each rank with a batch of timesteps and the
corresponding losses for each of those timesteps.
This method will perform synchronization to make sure all of the ranks
maintain the exact same reweighting.
:param local_ts: an integer Tensor of timesteps.
:param local_losses: a 1D Tensor of losses.
"""
batch_sizes = [
th.tensor([0], dtype=th.int32, device=local_ts.device)
for _ in range(dist.get_world_size())
]
dist.all_gather(
batch_sizes,
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
)
# Pad all_gather batches to be the maximum batch size.
batch_sizes = [x.item() for x in batch_sizes]
max_bs = max(batch_sizes)
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
dist.all_gather(timestep_batches, local_ts)
dist.all_gather(loss_batches, local_losses)
timesteps = [
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
]
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
self.update_with_all_losses(timesteps, losses)
@abstractmethod
def update_with_all_losses(self, ts, losses):
"""
Update the reweighting using losses from a model.
Sub-classes should override this method to update the reweighting
using losses from the model.
This method directly updates the reweighting without synchronizing
between workers. It is called by update_with_local_losses from all
ranks with identical arguments. Thus, it should have deterministic
behavior to maintain state across workers.
:param ts: a list of int timesteps.
:param losses: a list of float losses, one per timestep.
"""
class LossSecondMomentResampler(LossAwareSampler):
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
self.diffusion = diffusion
self.history_per_term = history_per_term
self.uniform_prob = uniform_prob
self._loss_history = np.zeros(
[diffusion.num_timesteps, history_per_term], dtype=np.float64
)
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
def weights(self):
if not self._warmed_up():
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
weights /= np.sum(weights)
weights *= 1 - self.uniform_prob
weights += self.uniform_prob / len(weights)
return weights
def update_with_all_losses(self, ts, losses):
for t, loss in zip(ts, losses):
if self._loss_counts[t] == self.history_per_term:
# Shift out the oldest loss term.
self._loss_history[t, :-1] = self._loss_history[t, 1:]
self._loss_history[t, -1] = loss
else:
self._loss_history[t, self._loss_counts[t]] = loss
self._loss_counts[t] += 1
def _warmed_up(self):
return (self._loss_counts == self.history_per_term).all()
================================================
FILE: models/__init__.py
================================================
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
from .unet import UNet3DConditionModel
from torch.optim.lr_scheduler import LambdaLR
def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
from torch.optim.lr_scheduler import LambdaLR
def fn(step):
if warmup_steps > 0:
return min(step / warmup_steps, 1)
else:
return 1
return LambdaLR(optimizer, fn)
def get_lr_scheduler(optimizer, name, **kwargs):
if name == 'warmup':
return customized_lr_scheduler(optimizer, **kwargs)
elif name == 'cosine':
from torch.optim.lr_scheduler import CosineAnnealingLR
return CosineAnnealingLR(optimizer, **kwargs)
else:
raise NotImplementedError(name)
def get_models(args):
if 'UNet' in args.model:
pretrained_model_path = args.pretrained_model_path
return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
else:
raise '{} Model Not Supported!'.format(args.model)
================================================
FILE: models/attention.py
================================================
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
from dataclasses import dataclass
from typing import Optional
import math
import torch
import torch.nn.functional as F
from torch import nn
from copy import deepcopy
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward, AdaLayerNorm
from rotary_embedding_torch import RotaryEmbedding
from typing import Callable, Optional
from einops import rearrange, repeat
try:
from diffusers.models.modeling_utils import ModelMixin
except:
from diffusers.modeling_utils import ModelMixin # 0.11.1
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def exists(x):
return x is not None
class CrossAttention(nn.Module):
r"""
copy from diffuser 0.11.1
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
use_relative_position: bool = False,
):
super().__init__()
# print('num head', heads)
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
# print(use_relative_position)
self.use_relative_position = use_relative_position
if self.use_relative_position:
self.rotary_emb = RotaryEmbedding(min(32, dim_head))
self.ip_transformed = False
self.ip_scale = 1
def ip_transform(self):
if self.ip_transformed is not True:
self.ip_to_k = deepcopy(self.to_k).to(next(self.parameters()).device)
self.ip_to_v = deepcopy(self.to_v).to(next(self.parameters()).device)
self.ip_transformed = True
def ip_train_set(self):
if self.ip_transformed is True:
self.ip_to_k.requires_grad_(True)
self.ip_to_v.requires_grad_(True)
def set_scale(self, scale):
self.ip_scale = scale
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def reshape_for_scores(self, tensor):
# split heads and dims
# tensor should be [b (h w)] f (d nd)
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).contiguous()
return tensor
def same_batch_dim_to_heads(self, tensor):
batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
dim = query.shape[-1]
if not self.use_relative_position:
query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
if not self.use_relative_position:
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if self.ip_transformed is True and ip_hidden_states is not None:
# print(ip_hidden_states.dtype)
# print(self.ip_to_k.weight.dtype)
ip_key = self.ip_to_k(ip_hidden_states)
ip_value = self.ip_to_v(ip_hidden_states)
if not self.use_relative_position:
ip_key = self.reshape_heads_to_batch_dim(ip_key)
ip_value = self.reshape_heads_to_batch_dim(ip_value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
if self.ip_transformed is True and ip_hidden_states is not None:
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, attention_mask)
ip_hidden_states = ip_hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
if self.ip_transformed is True and ip_hidden_states is not None:
ip_hidden_states = self._attention(query, ip_key, ip_value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
if self.ip_transformed is True and ip_hidden_states is not None:
ip_hidden_states = self._sliced_attention(query, ip_key, ip_value, sequence_length, dim, attention_mask)
if self.ip_transformed is True and ip_hidden_states is not None:
hidden_states = hidden_states + self.ip_scale * ip_hidden_states
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
attention_probs = attention_probs.to(value.dtype)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
if self.upcast_softmax:
attn_slice = attn_slice.float()
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
use_first_frame: bool = False,
use_relative_position: bool = False,
rotary_emb: bool = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True, ip_hidden_states=None, encoder_temporal_hidden_states=None):
# Input
# if ip_hidden_states is not None:
# ip_hidden_states = ip_hidden_states.to(dtype=encoder_hidden_states.dtype)
# print(ip_hidden_states.shape)
# print(encoder_hidden_states.shape)
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
if self.training:
video_length = hidden_states.shape[2] - use_image_num
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
encoder_hidden_states_length = encoder_hidden_states.shape[1]
encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
if ip_hidden_states is not None:
ip_hidden_states_length = ip_hidden_states.shape[1]
ip_hidden_states_video = ip_hidden_states[:, :ip_hidden_states_length - use_image_num, ...]
ip_hidden_states_video = repeat(ip_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
ip_hidden_states_image = ip_hidden_states[:, ip_hidden_states_length - use_image_num:, ...]
ip_hidden_states = torch.cat([ip_hidden_states_video, ip_hidden_states_image], dim=1)
ip_hidden_states = rearrange(ip_hidden_states, 'b m n c -> (b m) n c').contiguous()
else:
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
if encoder_temporal_hidden_states is not None:
encoder_temporal_hidden_states = repeat(encoder_temporal_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
if ip_hidden_states is not None:
ip_hidden_states = repeat(ip_hidden_states, 'b 1 n c -> (b f) n c', f=video_length).contiguous()
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
video_length=video_length,
use_image_num=use_image_num,
ip_hidden_states=ip_hidden_states,
encoder_temporal_hidden_states=encoder_temporal_hidden_states
)
# Output
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
use_first_frame: bool = False,
use_relative_position: bool = False,
rotary_emb: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
# print(only_cross_attention)
self.use_ada_layer_norm = num_embeds_ada_norm is not None
# print(self.use_ada_layer_norm)
self.use_first_frame = use_first_frame
self.dim = dim
self.cross_attention_dim = cross_attention_dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.attention_bias = attention_bias
self.upcast_attention = upcast_attention
# Spatial-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Text Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None
# Temp
self.attn_temp = TemporalAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
upcast_attention=upcast_attention,
rotary_emb=rotary_emb,
)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.tca_transformed = False
def tca_transform(self):
if self.tca_transformed is not True:
self.cross_attn_temp = CrossAttention(
query_dim=self.dim * 16,
cross_attention_dim=self.cross_attention_dim,
heads=self.num_attention_heads,
dim_head=self.attention_head_dim,
dropout=self.dropout,
bias=self.attention_bias,
upcast_attention=self.upcast_attention,
)
self.cross_norm_temp = AdaLayerNorm(self.dim * 16, self.num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(self.dim * 16)
nn.init.zeros_(self.cross_attn_temp.to_out[0].weight.data)
self.tca_transformed = True
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
if self.attn2 is not None:
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
if self.only_cross_attention:
hidden_states = (
self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
)
else:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ip_hidden_states=ip_hidden_states
)
+ hidden_states
)
# Temporal Attention
if self.training:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
hidden_states_video = hidden_states[:, :video_length, :]
hidden_states_image = hidden_states[:, video_length:, :]
norm_hidden_states_video = (
self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
)
hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
# Temporal Cross Attention
if self.tca_transformed is True:
hidden_states_video = rearrange(hidden_states_video, "(b d) f c -> b d (f c)", d=d).contiguous()
norm_hidden_states_video = (
self.cross_norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.cross_norm_temp(hidden_states_video)
)
temp_encoder_hidden_states = rearrange(encoder_hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
temp_encoder_hidden_states = temp_encoder_hidden_states[:, 0:1].squeeze(dim=1)
hidden_states_video = self.cross_attn_temp(norm_hidden_states_video, encoder_hidden_states=temp_encoder_hidden_states, attention_mask=attention_mask) + hidden_states_video
hidden_states_video = rearrange(hidden_states_video, "b d (f c) -> (b d) f c", f=video_length).contiguous()
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
else:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
# Temporal Cross Attention
if self.tca_transformed is True:
hidden_states = rearrange(hidden_states, "(b d) f c -> b d (f c)", d=d).contiguous()
norm_hidden_states = (
self.cross_norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.cross_norm_temp(hidden_states)
)
if encoder_temporal_hidden_states is not None:
encoder_hidden_states = encoder_temporal_hidden_states
temp_encoder_hidden_states = rearrange(encoder_hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
temp_encoder_hidden_states = temp_encoder_hidden_states[:, 0:1].squeeze(dim=1)
hidden_states = self.cross_attn_temp(norm_hidden_states, encoder_hidden_states=temp_encoder_hidden_states, attention_mask=attention_mask) + hidden_states
hidden_states = rearrange(hidden_states, "b d (f c) -> (b f) d c", f=video_length + use_image_num, d=d).contiguous()
else:
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class SparseCausalAttention(CrossAttention):
def forward_video(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
former_frame_index = torch.arange(video_length) - 1
former_frame_index[0] = 0
key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous()
key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
key = rearrange(key, "b f d c -> (b f) d c").contiguous()
value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous()
value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
value = rearrange(value, "b f d c -> (b f) d c").contiguous()
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def forward_image(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
dim = query.shape[-1]
if not self.use_relative_position:
query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
if not self.use_relative_position:
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_image_num=None):
if self.training:
# print(use_image_num)
hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length + use_image_num).contiguous()
hidden_states_video = hidden_states[:, :video_length, ...]
hidden_states_image = hidden_states[:, video_length:, ...]
hidden_states_video = rearrange(hidden_states_video, 'b f d c -> (b f) d c').contiguous()
hidden_states_image = rearrange(hidden_states_image, 'b f d c -> (b f) d c').contiguous()
hidden_states_video = self.forward_video(hidden_states=hidden_states_video,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
video_length=video_length)
hidden_states_image = self.forward_image(hidden_states=hidden_states_image,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=0)
return hidden_states
# exit()
else:
return self.forward_video(hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
video_length=video_length)
class TemporalAttention(CrossAttention):
def __init__(self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
rotary_emb=None):
super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
# relative time positional embeddings
self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
self.rotary_emb = rotary_emb
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
dim = query.shape[-1]
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
if self.upcast_attention:
query = query.float()
key = key.float()
query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
# torch.baddbmm only accepte 3-D tensor
# https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
# attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
if exists(self.rotary_emb):
query = self.rotary_emb.rotate_queries_or_keys(query)
key = self.rotary_emb.rotate_queries_or_keys(key)
attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
attention_scores = attention_scores + time_rel_pos_bias
if attention_mask is not None:
# add attention mask
attention_scores = attention_scores + attention_mask
# vdm
attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# print(attention_probs[0][0])
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
return hidden_states
class RelativePositionBias(nn.Module):
def __init__(
self,
heads=8,
num_buckets=32,
max_distance=128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, n, device):
q_pos = torch.arange(n, dtype = torch.long, device = device)
k_pos = torch.arange(n, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
================================================
FILE: models/clip.py
================================================
import numpy
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel
import transformers
transformers.logging.set_verbosity_error()
"""
Will encounter following warning:
- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
https://github.com/CompVis/stable-diffusion/issues/97
according to this issue, this warning is safe.
This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
You can safely ignore the warning, it is not an error.
This clip usage is from U-ViT and same with Stable Diffusion.
"""
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
# def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
def __init__(self, path, device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
self.device = device
self.max_length = max_length
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class TextEmbedder(nn.Module):
"""
Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
"""
def __init__(self, path, dropout_prob=0.1):
super().__init__()
self.text_encodder = FrozenCLIPEmbedder(path=path)
self.dropout_prob = dropout_prob
def token_drop(self, text_prompts, force_drop_ids=None):
"""
Drops text to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
else:
# TODO
drop_ids = force_drop_ids == 1
labels = list(numpy.where(drop_ids, "", text_prompts))
# print(labels)
return labels
def forward(self, text_prompts, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
text_prompts = self.token_drop(text_prompts, force_drop_ids)
embeddings = self.text_encodder(text_prompts)
return embeddings
if __name__ == '__main__':
r"""
Returns:
Examples from CLIPTextModel:
```python
>>> from transformers import AutoTokenizer, CLIPTextModel
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
dropout_prob=0.00001).to(device)
text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
# text_prompt = ('None', 'None', 'None')
output = text_encoder(text_prompts=text_prompt, train=False)
# print(output)
print(output.shape)
# print(output.shape)
================================================
FILE: models/resnet.py
================================================
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
================================================
FILE: models/unet.py
================================================
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import math
import json
import torch
import einops
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
try:
from diffusers.models.modeling_utils import ModelMixin
except:
from diffusers.modeling_utils import ModelMixin # 0.11.1
try:
from .unet_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from .resnet import InflatedConv3d
except:
from unet_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from resnet import InflatedConv3d
from rotary_embedding_torch import RotaryEmbedding
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class RelativePositionBias(nn.Module):
def __init__(
self,
heads=8,
num_buckets=32,
max_distance=128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, n, device):
q_pos = torch.arange(n, dtype = torch.long, device = device)
k_pos = torch.arange(n, dtype = torch.long, device = device)
rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
@dataclass
class UNet3DConditionOutput(BaseOutput):
sample: torch.FloatTensor
class UNet3DConditionModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None, # 64
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
mid_block_type: str = "UNetMidBlock3DCrossAttn",
up_block_types: Tuple[str] = (
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D"
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
use_first_frame: bool = False,
use_relative_position: bool = False,
):
super().__init__()
# print(use_first_frame)
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
rotary_emb = RotaryEmbedding(32)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == "UNetMidBlock3DCrossAttn":
self.mid_block = UNetMidBlock3DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the videos
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
# relative time positional embeddings
self.use_relative_position = use_relative_position
if self.use_relative_position:
self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_slicable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)
num_slicable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor = None,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_image_num: int = 0,
return_dict: bool = True,
ip_hidden_states = None,
encoder_temporal_hidden_states = None
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
if ip_hidden_states is not None:
b = ip_hidden_states.shape[0]
ip_hidden_states = rearrange(ip_hidden_states, 'b n c -> (b n) c')
ip_hidden_states = self.image_proj_model(ip_hidden_states)
ip_hidden_states = rearrange(ip_hidden_states, '(b n) m c -> b n m c', b=b)
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# time
timesteps = timestep
if not torch.is_tensor(timesteps):
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
# print(emb.shape) # torch.Size([3, 1280])
# print(class_emb.shape) # torch.Size([3, 1280])
emb = emb + class_emb
if self.use_relative_position:
frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
else:
frame_rel_pos_bias = None
# pre-process
sample = self.conv_in(sample)
# down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
use_image_num=use_image_num,
ip_hidden_states=ip_hidden_states,
encoder_temporal_hidden_states=encoder_temporal_hidden_states
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# mid
sample = self.mid_block(
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states
)
# up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
attention_mask=attention_mask,
use_image_num=use_image_num,
ip_hidden_states=ip_hidden_states,
encoder_temporal_hidden_states=encoder_temporal_hidden_states
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# print(sample.shape)
if not return_dict:
return (sample,)
sample = UNet3DConditionOutput(sample=sample)
return sample
def forward_with_cfg(self,
x,
t,
encoder_hidden_states = None,
class_labels: Optional[torch.Tensor] = None,
cfg_scale=4.0,
use_fp16=False,
ip_hidden_states = None):
"""
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
if use_fp16:
combined = combined.to(dtype=torch.float16)
model_out = self.forward(combined, t, encoder_hidden_states, class_labels, ip_hidden_states=ip_hidden_states).sample
# For exact reproducibility reasons, we apply classifier-free guidance on only
# three channels by default. The standard approach to cfg applies it to all channels.
# This can be done by uncommenting the following line and commenting-out the line following that.
eps, rest = model_out[:, :4], model_out[:, 4:]
# eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
# the content of the config file
# {
# "_class_name": "UNet2DConditionModel",
# "_diffusers_version": "0.2.2",
# "act_fn": "silu",
# "attention_head_dim": 8,
# "block_out_channels": [
# 320,
# 640,
# 1280,
# 1280
# ],
# "center_input_sample": false,
# "cross_attention_dim": 768,
# "down_block_types": [
# "CrossAttnDownBlock2D",
# "CrossAttnDownBlock2D",
# "CrossAttnDownBlock2D",
# "DownBlock2D"
# ],
# "downsample_padding": 1,
# "flip_sin_to_cos": true,
# "freq_shift": 0,
# "in_channels": 4,
# "layers_per_block": 2,
# "mid_block_scale_factor": 1,
# "norm_eps": 1e-05,
# "norm_num_groups": 32,
# "out_channels": 4,
# "sample_size": 64,
# "up_block_types": [
# "UpBlock2D",
# "CrossAttnUpBlock2D",
# "CrossAttnUpBlock2D",
# "CrossAttnUpBlock2D"
# ]
# }
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
config["_class_name"] = cls.__name__
config["down_block_types"] = [
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D"
]
config["up_block_types"] = [
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D"
]
# config["use_first_frame"] = True
config["use_first_frame"] = False
if use_concat:
config["in_channels"] = 9
# config["use_relative_position"] = True
# # tmp
# config["class_embed_type"] = "timestep"
# config["num_class_embeds"] = 100
from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
# {'_class_name': 'UNet3DConditionModel',
# '_diffusers_version': '0.2.2',
# 'act_fn': 'silu',
# 'attention_head_dim': 8,
# 'block_out_channels': [320, 640, 1280, 1280],
# 'center_input_sample': False,
# 'cross_attention_dim': 768,
# 'down_block_types':
# ['CrossAttnDownBlock3D',
# 'CrossAttnDownBlock3D',
# 'CrossAttnDownBlock3D',
# 'DownBlock3D'],
# 'downsample_padding': 1,
# 'flip_sin_to_cos': True,
# 'freq_shift': 0,
# 'in_channels': 4,
# 'layers_per_block': 2,
# 'mid_block_scale_factor': 1,
# 'norm_eps': 1e-05,
# 'norm_num_groups': 32,
# 'out_channels': 4,
# 'sample_size': 64,
# 'up_block_types':
# ['UpBlock3D',
# 'CrossAttnUpBlock3D',
# 'CrossAttnUpBlock3D',
# 'CrossAttnUpBlock3D']}
model = cls.from_config(config)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
if use_concat:
new_state_dict = {}
conv_in_weight = state_dict["conv_in.weight"]
new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
new_conv_weight[:, j] = conv_in_weight[:, i]
new_state_dict["conv_in.weight"] = new_conv_weight
new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
for k, v in model.state_dict().items():
# print(k)
if '_temp.' in k:
new_state_dict.update({k: v})
if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
k = k.replace('attn_fcross', 'attn1')
state_dict.update({k: state_dict[k]})
if 'norm_fcross' in k:
k = k.replace('norm_fcross', 'norm1')
state_dict.update({k: state_dict[k]})
if 'conv_in' in k:
continue
else:
new_state_dict[k] = v
# # tmp
# if 'class_embedding' in k:
# state_dict.update({k: v})
# breakpoint()
model.load_state_dict(new_state_dict)
else:
for k, v in model.state_dict().items():
# print(k)
if '_temp' in k:
state_dict.update({k: v})
if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
k = k.replace('attn_fcross', 'attn1')
state_dict.update({k: state_dict[k]})
if 'norm_fcross' in k:
k = k.replace('norm_fcross', 'norm1')
state_dict.update({k: state_dict[k]})
model.load_state_dict(state_dict)
return model
================================================
FILE: models/unet_blocks.py
================================================
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import torch
from torch import nn
try:
from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
except:
from attention import Transformer3DModel
from resnet import Downsample3D, ResnetBlock3D, Upsample3D
def get_down_block(
down_block_type,
num_layers,
in_channels,
out_channels,
temb_channels,
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
use_first_frame=False,
use_relative_position=False,
rotary_emb=False,
):
# print(down_block_type)
# print(use_first_frame)
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "CrossAttnDownBlock3D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
return CrossAttnDownBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type,
num_layers,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
use_first_frame=False,
use_relative_position=False,
rotary_emb=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
return CrossAttnUpBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock3DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
use_first_frame=False,
use_relative_position=False,
rotary_emb=False,
):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
)
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
class CrossAttnDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
use_first_frame=False,
use_relative_position=False,
rotary_emb=False,
):
super().__init__()
resnets = []
attentions = []
# print(use_first_frame)
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_first_frame=use_first_frame,
use_relative_position=use_relative_position,
rotary_emb=rotary_emb,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample3D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None, ip_hidden_states=None, encoder_temporal_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
def create_custom_forward_attn(module, return_dict=None, use_image_num=None, ip_hidden_states=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
else:
return module(*inputs, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states),
hidden_states,
encoder_hidden_states,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num, ip_hidden_states=ip_hidden_states, encoder_temporal_hidden_states=encoder_temporal_hidden_states).sample
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class DownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample3D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def crea
gitextract_ybj531gl/
├── LICENSE
├── MSYH.TTC
├── README.md
├── configs/
│ ├── vlog_read_script_sample.yaml
│ ├── vlog_write_script.yaml
│ ├── with_mask_ref_sample.yaml
│ └── with_mask_sample.yaml
├── datasets/
│ └── video_transforms.py
├── diffusion/
│ ├── __init__.py
│ ├── diffusion_utils.py
│ ├── gaussian_diffusion.py
│ ├── respace.py
│ └── timestep_sampler.py
├── models/
│ ├── __init__.py
│ ├── attention.py
│ ├── clip.py
│ ├── resnet.py
│ ├── unet.py
│ ├── unet_blocks.py
│ └── utils.py
├── requirements.txt
├── results/
│ └── vlog/
│ ├── teddy_travel/
│ │ ├── script/
│ │ │ ├── audio_prompts.txt
│ │ │ ├── protagonist_place_reference.txt
│ │ │ ├── protagonists_places.txt
│ │ │ ├── time_scripts.txt
│ │ │ ├── video_prompts.txt
│ │ │ └── zh_video_prompts.txt
│ │ └── story.txt
│ └── teddy_travel_/
│ └── story.txt
├── sample_scripts/
│ ├── vlog_read_script_sample.py
│ ├── vlog_write_script.py
│ ├── with_mask_ref_sample.py
│ └── with_mask_sample.py
├── utils.py
└── vlogger/
├── STEB/
│ └── model_transform.py
├── planning_utils/
│ └── gpt4_utils.py
├── videoaudio.py
├── videocaption.py
└── videofusion.py
SYMBOL INDEX (269 symbols across 23 files)
FILE: datasets/video_transforms.py
function _is_tensor_video_clip (line 7) | def _is_tensor_video_clip(clip):
function center_crop_arr (line 17) | def center_crop_arr(pil_image, image_size):
function crop (line 38) | def crop(clip, i, j, h, w):
function resize (line 48) | def resize(clip, target_size, interpolation_mode):
function resize_scale (line 53) | def resize_scale(clip, target_size, interpolation_mode):
function resize_with_scale_factor (line 60) | def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
function resize_scale_with_height (line 63) | def resize_scale_with_height(clip, target_size, interpolation_mode):
function resize_scale_with_weight (line 68) | def resize_scale_with_weight(clip, target_size, interpolation_mode):
function resized_crop (line 74) | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
function center_crop (line 94) | def center_crop(clip, crop_size):
function center_crop_using_short_edge (line 109) | def center_crop_using_short_edge(clip):
function random_shift_crop (line 124) | def random_shift_crop(clip):
function to_tensor (line 146) | def to_tensor(clip):
function normalize (line 162) | def normalize(clip, mean, std, inplace=False):
function hflip (line 182) | def hflip(clip):
class RandomCropVideo (line 194) | class RandomCropVideo:
method __init__ (line 195) | def __init__(self, size):
method __call__ (line 201) | def __call__(self, clip):
method get_params (line 212) | def get_params(self, clip):
method __repr__ (line 227) | def __repr__(self) -> str:
class CenterCropResizeVideo (line 230) | class CenterCropResizeVideo:
method __init__ (line 235) | def __init__(
method __call__ (line 250) | def __call__(self, clip):
method __repr__ (line 264) | def __repr__(self) -> str:
class CenterCropVideo (line 268) | class CenterCropVideo:
method __init__ (line 269) | def __init__(
method __call__ (line 284) | def __call__(self, clip):
method __repr__ (line 295) | def __repr__(self) -> str:
class NormalizeVideo (line 299) | class NormalizeVideo:
method __init__ (line 308) | def __init__(self, mean, std, inplace=False):
method __call__ (line 313) | def __call__(self, clip):
method __repr__ (line 320) | def __repr__(self) -> str:
class ToTensorVideo (line 324) | class ToTensorVideo:
method __init__ (line 330) | def __init__(self):
method __call__ (line 333) | def __call__(self, clip):
method __repr__ (line 342) | def __repr__(self) -> str:
class ResizeVideo (line 346) | class ResizeVideo():
method __init__ (line 351) | def __init__(
method __call__ (line 366) | def __call__(self, clip):
method __repr__ (line 377) | def __repr__(self) -> str:
FILE: diffusion/__init__.py
function create_diffusion (line 10) | def create_diffusion(
FILE: diffusion/diffusion_utils.py
function normal_kl (line 10) | def normal_kl(mean1, logvar1, mean2, logvar2):
function approx_standard_normal_cdf (line 39) | def approx_standard_normal_cdf(x):
function continuous_gaussian_log_likelihood (line 47) | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
function discretized_gaussian_log_likelihood (line 62) | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
FILE: diffusion/gaussian_diffusion.py
function mean_flat (line 16) | def mean_flat(tensor):
class ModelMeanType (line 23) | class ModelMeanType(enum.Enum):
class ModelVarType (line 33) | class ModelVarType(enum.Enum):
class LossType (line 46) | class LossType(enum.Enum):
method is_vb (line 54) | def is_vb(self):
function _warmup_beta (line 58) | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_f...
function get_beta_schedule (line 65) | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffus...
function get_named_beta_schedule (line 98) | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
function betas_for_alpha_bar (line 128) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
class GaussianDiffusion (line 147) | class GaussianDiffusion:
method __init__ (line 156) | def __init__(
method q_mean_variance (line 206) | def q_mean_variance(self, x_start, t):
method q_sample (line 218) | def q_sample(self, x_start, t, noise=None):
method q_posterior_mean_variance (line 235) | def q_posterior_mean_variance(self, x_start, x_t, t):
method p_mean_variance (line 257) | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn...
method _predict_xstart_from_eps (line 346) | def _predict_xstart_from_eps(self, x_t, t, eps):
method _predict_eps_from_xstart (line 353) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method condition_mean (line 358) | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
method condition_score (line 370) | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
method p_sample (line 388) | def p_sample(
method p_sample_loop (line 437) | def p_sample_loop(
method p_sample_loop_progressive (line 488) | def p_sample_loop_progressive(
method ddim_sample (line 543) | def ddim_sample(
method ddim_reverse_sample (line 598) | def ddim_reverse_sample(
method ddim_sample_loop (line 636) | def ddim_sample_loop(
method ddim_sample_loop_progressive (line 675) | def ddim_sample_loop_progressive(
method _vb_terms_bpd (line 730) | def _vb_terms_bpd(
method training_losses (line 763) | def training_losses(self, model, x_start, t, model_kwargs=None, noise=...
method _prior_bpd (line 847) | def _prior_bpd(self, x_start):
method calc_bpd_loop (line 863) | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwar...
function _extract_into_tensor (line 919) | def _extract_into_tensor(arr, timesteps, broadcast_shape):
FILE: diffusion/respace.py
function space_timesteps (line 12) | def space_timesteps(num_timesteps, section_counts):
class SpacedDiffusion (line 65) | class SpacedDiffusion(GaussianDiffusion):
method __init__ (line 73) | def __init__(self, use_timesteps, **kwargs):
method p_mean_variance (line 89) | def p_mean_variance(
method training_losses (line 95) | def training_losses(
method condition_mean (line 100) | def condition_mean(self, cond_fn, *args, **kwargs):
method condition_score (line 103) | def condition_score(self, cond_fn, *args, **kwargs):
method _wrap_model (line 106) | def _wrap_model(self, model):
method _scale_timesteps (line 113) | def _scale_timesteps(self, t):
class _WrappedModel (line 118) | class _WrappedModel:
method __init__ (line 119) | def __init__(self, model, timestep_map, original_num_steps):
method __call__ (line 125) | def __call__(self, x, ts, **kwargs):
FILE: diffusion/timestep_sampler.py
function create_named_schedule_sampler (line 13) | def create_named_schedule_sampler(name, diffusion):
class ScheduleSampler (line 27) | class ScheduleSampler(ABC):
method weights (line 38) | def weights(self):
method sample (line 44) | def sample(self, batch_size, device):
class UniformSampler (line 62) | class UniformSampler(ScheduleSampler):
method __init__ (line 63) | def __init__(self, diffusion):
method weights (line 67) | def weights(self):
class LossAwareSampler (line 71) | class LossAwareSampler(ScheduleSampler):
method update_with_local_losses (line 72) | def update_with_local_losses(self, local_ts, local_losses):
method update_with_all_losses (line 106) | def update_with_all_losses(self, ts, losses):
class LossSecondMomentResampler (line 120) | class LossSecondMomentResampler(LossAwareSampler):
method __init__ (line 121) | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
method weights (line 130) | def weights(self):
method update_with_all_losses (line 139) | def update_with_all_losses(self, ts, losses):
method _warmed_up (line 149) | def _warmed_up(self):
FILE: models/__init__.py
function customized_lr_scheduler (line 8) | def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u...
function get_lr_scheduler (line 18) | def get_lr_scheduler(optimizer, name, **kwargs):
function get_models (line 27) | def get_models(args):
FILE: models/attention.py
class Transformer3DModelOutput (line 28) | class Transformer3DModelOutput(BaseOutput):
function exists (line 38) | def exists(x):
class CrossAttention (line 42) | class CrossAttention(nn.Module):
method __init__ (line 57) | def __init__(
method ip_transform (line 115) | def ip_transform(self):
method ip_train_set (line 121) | def ip_train_set(self):
method set_scale (line 126) | def set_scale(self, scale):
method reshape_heads_to_batch_dim (line 129) | def reshape_heads_to_batch_dim(self, tensor):
method reshape_batch_dim_to_heads (line 136) | def reshape_batch_dim_to_heads(self, tensor):
method reshape_for_scores (line 143) | def reshape_for_scores(self, tensor):
method same_batch_dim_to_heads (line 152) | def same_batch_dim_to_heads(self, tensor):
method set_attention_slice (line 157) | def set_attention_slice(self, slice_size):
method forward (line 163) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
method _attention (line 248) | def _attention(self, query, key, value, attention_mask=None):
method _sliced_attention (line 273) | def _sliced_attention(self, query, key, value, sequence_length, dim, a...
method _memory_efficient_attention_xformers (line 316) | def _memory_efficient_attention_xformers(self, query, key, value, atte...
class Transformer3DModel (line 326) | class Transformer3DModel(ModelMixin, ConfigMixin):
method __init__ (line 328) | def __init__(
method forward (line 390) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
class BasicTransformerBlock (line 472) | class BasicTransformerBlock(nn.Module):
method __init__ (line 473) | def __init__(
method tca_transform (line 555) | def tca_transform(self):
method set_use_memory_efficient_attention_xformers (line 570) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
method forward (line 598) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
class SparseCausalAttention (line 676) | class SparseCausalAttention(CrossAttention):
method forward_video (line 677) | def forward_video(self, hidden_states, encoder_hidden_states=None, att...
method forward_image (line 734) | def forward_image(self, hidden_states, encoder_hidden_states=None, att...
method forward (line 793) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
class TemporalAttention (line 817) | class TemporalAttention(CrossAttention):
method __init__ (line 818) | def __init__(self,
method forward (line 835) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
method _attention (line 890) | def _attention(self, query, key, value, attention_mask=None, time_rel_...
class RelativePositionBias (line 928) | class RelativePositionBias(nn.Module):
method __init__ (line 929) | def __init__(
method _relative_position_bucket (line 941) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 960) | def forward(self, n, device):
FILE: models/clip.py
class AbstractEncoder (line 24) | class AbstractEncoder(nn.Module):
method __init__ (line 25) | def __init__(self):
method encode (line 28) | def encode(self, *args, **kwargs):
class FrozenCLIPEmbedder (line 32) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 35) | def __init__(self, path, device="cuda", max_length=77):
method freeze (line 43) | def freeze(self):
method forward (line 48) | def forward(self, text):
method encode (line 57) | def encode(self, text):
class TextEmbedder (line 61) | class TextEmbedder(nn.Module):
method __init__ (line 65) | def __init__(self, path, dropout_prob=0.1):
method token_drop (line 70) | def token_drop(self, text_prompts, force_drop_ids=None):
method forward (line 83) | def forward(self, text_prompts, train, force_drop_ids=None):
FILE: models/resnet.py
class InflatedConv3d (line 13) | class InflatedConv3d(nn.Conv2d):
method forward (line 14) | def forward(self, x):
class Upsample3D (line 24) | class Upsample3D(nn.Module):
method __init__ (line 25) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
method forward (line 44) | def forward(self, hidden_states, output_size=None):
class Downsample3D (line 79) | class Downsample3D(nn.Module):
method __init__ (line 80) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
method forward (line 102) | def forward(self, hidden_states):
class ResnetBlock3D (line 113) | class ResnetBlock3D(nn.Module):
method __init__ (line 114) | def __init__(
method forward (line 177) | def forward(self, input_tensor, temb):
class Mish (line 210) | class Mish(torch.nn.Module):
method forward (line 211) | def forward(self, hidden_states):
FILE: models/unet.py
class RelativePositionBias (line 54) | class RelativePositionBias(nn.Module):
method __init__ (line 55) | def __init__(
method _relative_position_bucket (line 67) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 86) | def forward(self, n, device):
class UNet3DConditionOutput (line 95) | class UNet3DConditionOutput(BaseOutput):
class UNet3DConditionModel (line 99) | class UNet3DConditionModel(ModelMixin, ConfigMixin):
method __init__ (line 103) | def __init__(
method set_attention_slice (line 291) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 356) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 360) | def forward(
method forward_with_cfg (line 519) | def forward_with_cfg(self,
method from_pretrained_2d (line 547) | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use...
FILE: models/unet_blocks.py
function get_down_block (line 17) | def get_down_block(
function get_up_block (line 82) | def get_up_block(
class UNetMidBlock3DCrossAttn (line 145) | class UNetMidBlock3DCrossAttn(nn.Module):
method __init__ (line 146) | def __init__(
method forward (line 226) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None...
class CrossAttnDownBlock3D (line 235) | class CrossAttnDownBlock3D(nn.Module):
method __init__ (line 236) | def __init__(
method forward (line 320) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None...
class DownBlock3D (line 365) | class DownBlock3D(nn.Module):
method __init__ (line 366) | def __init__(
method forward (line 417) | def forward(self, hidden_states, temb=None):
class CrossAttnUpBlock3D (line 444) | class CrossAttnUpBlock3D(nn.Module):
method __init__ (line 445) | def __init__(
method forward (line 524) | def forward(
class UpBlock3D (line 579) | class UpBlock3D(nn.Module):
method __init__ (line 580) | def __init__(
method forward (line 627) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
FILE: models/utils.py
function checkpoint (line 25) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 42) | class CheckpointFunction(torch.autograd.Function):
method forward (line 44) | def forward(ctx, run_function, length, *args):
method backward (line 54) | def backward(ctx, *output_grads):
function timestep_embedding (line 74) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 97) | def zero_module(module):
function scale_module (line 106) | def scale_module(module, scale):
function mean_flat (line 115) | def mean_flat(tensor):
function normalization (line 122) | def normalization(channels):
class SiLU (line 132) | class SiLU(nn.Module):
method forward (line 133) | def forward(self, x):
class GroupNorm32 (line 137) | class GroupNorm32(nn.GroupNorm):
method forward (line 138) | def forward(self, x):
function conv_nd (line 141) | def conv_nd(dims, *args, **kwargs):
function linear (line 154) | def linear(*args, **kwargs):
function avg_pool_nd (line 161) | def avg_pool_nd(dims, *args, **kwargs):
function noise_like (line 187) | def noise_like(shape, device, repeat=False):
function count_flops_attn (line 192) | def count_flops_attn(model, _x, y):
function count_params (line 211) | def count_params(model, verbose=False):
FILE: sample_scripts/vlog_read_script_sample.py
function auto_inpainting (line 37) | def auto_inpainting(args,
function main (line 110) | def main(args):
FILE: sample_scripts/vlog_write_script.py
function main (line 18) | def main(args):
FILE: sample_scripts/with_mask_ref_sample.py
function get_input (line 48) | def get_input(args):
function auto_inpainting (line 120) | def auto_inpainting(args,
function main (line 185) | def main(args):
FILE: sample_scripts/with_mask_sample.py
function get_input (line 47) | def get_input(args):
function auto_inpainting (line 119) | def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, ...
function main (line 176) | def main(args):
FILE: utils.py
function fetch_files_by_numbers (line 20) | def fetch_files_by_numbers(start_number, count, file_list):
function get_grad_norm (line 35) | def get_grad_norm(
function clip_grad_norm_ (line 72) | def clip_grad_norm_(
function separation_content_motion (line 129) | def separation_content_motion(video_clip):
function get_experiment_dir (line 147) | def get_experiment_dir(root_dir, args):
function create_logger (line 166) | def create_logger(logging_dir):
function create_accelerate_logger (line 185) | def create_accelerate_logger(logging_dir, is_main_process=False):
function create_tensorboard (line 204) | def create_tensorboard(tensorboard_dir):
function write_tensorboard (line 214) | def write_tensorboard(writer, *args):
function update_ema (line 227) | def update_ema(ema_model, model, decay=0.9999):
function requires_grad (line 239) | def requires_grad(model, flag=True):
function cleanup (line 246) | def cleanup():
function setup_distributed (line 253) | def setup_distributed(backend="nccl", port=None):
function save_video_grid (line 292) | def save_video_grid(video, nrow=None):
function save_videos_grid_tav (line 312) | def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False,...
function collect_env (line 336) | def collect_env():
function mask_generation_before (line 356) | def mask_generation_before(mask_type, shape, dtype, device, dropout_prob...
FILE: vlogger/STEB/model_transform.py
function tca_transform_model (line 11) | def tca_transform_model(model):
class ImageProjModel (line 32) | class ImageProjModel(torch.nn.Module):
method __init__ (line 34) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
method forward (line 42) | def forward(self, image_embeds):
function ip_transform_model (line 49) | def ip_transform_model(model):
function ip_scale_set (line 72) | def ip_scale_set(model, scale):
function ip_train_set (line 93) | def ip_train_set(model):
FILE: vlogger/planning_utils/gpt4_utils.py
function smart_openai_key (line 12) | def smart_openai_key():
function json_completion (line 20) | def json_completion(prompt):
function ExtractProtagonist (line 41) | def ExtractProtagonist(story, file_path):
function ExtractAProtagonist (line 89) | def ExtractAProtagonist(story, file_path):
function protagonist_place_reference (line 131) | def protagonist_place_reference(video_list, character_places):
function protagonist_place_reference1 (line 193) | def protagonist_place_reference1(video_list, character_places, file_path):
function split_story (line 260) | def split_story(story, file_path):
function patch_story_scripts (line 309) | def patch_story_scripts(story, video_list, file_path):
function refine_story_scripts (line 361) | def refine_story_scripts(video_list, file_path):
function time_scripts (line 413) | def time_scripts(video_list, file_path):
function translate_video_script (line 479) | def translate_video_script(video_list, file_path):
function readscript (line 527) | def readscript(script_file_path):
function readzhscript (line 537) | def readzhscript(zh_file_path):
function readtimescript (line 547) | def readtimescript(time_file_path):
function readprotagonistscript (line 558) | def readprotagonistscript(protagonist_file_path):
function readreferencescript (line 568) | def readreferencescript(video_list, character_places, reference_file_path):
FILE: vlogger/videoaudio.py
function make_audio (line 15) | def make_audio(en_prompt_file, output_dir):
function merge_video_audio (line 41) | def merge_video_audio(video_dir, audio_dir, output_dir):
function concatenate_videos (line 81) | def concatenate_videos(video_dir, output_dir=None):
FILE: vlogger/videocaption.py
function captioning (line 11) | def captioning(en_prompt_file, zh_prompt_file, input_video_dir, output_v...
FILE: vlogger/videofusion.py
function fusion (line 8) | def fusion(path):
Condensed preview — 39 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (333K chars).
[
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 13819,
"preview": "<div align=\"center\">\n\n<h1 align=\"center\">Vlogger: Make Your Dream A Vlog</h1>\n</a>\n\n[Shaobin Zhuang](https://github.com/"
},
{
"path": "configs/vlog_read_script_sample.yaml",
"chars": 1128,
"preview": "# path:\nckpt: \"pretrained/ShowMaker.pt\"\npretrained_model_path: \"pretrained/stable-diffusion-v1-4/\"\nimage_encoder_path: \""
},
{
"path": "configs/vlog_write_script.yaml",
"chars": 94,
"preview": "# script path\nstory_path: \"./results/vlog/teddy_travel_/story.txt\"\nonly_one_protagonist: False"
},
{
"path": "configs/with_mask_ref_sample.yaml",
"chars": 982,
"preview": "# path config:\nckpt: \"pretrained/ShowMaker.pt\"\npretrained_model_path: \"pretrained/stable-diffusion-v1-4/\"\nimage_encoder_"
},
{
"path": "configs/with_mask_sample.yaml",
"chars": 866,
"preview": "# path config:\nckpt: \"pretrained/ShowMaker.pt\"\npretrained_model_path: \"pretrained/OpenCLIP-ViT-H-14\"\ninput_path: 'input/"
},
{
"path": "datasets/video_transforms.py",
"chars": 12929,
"preview": "import torch\r\nimport random\r\nimport numbers\r\nfrom torchvision.transforms import RandomCrop, RandomResizedCrop\r\nfrom PIL "
},
{
"path": "diffusion/__init__.py",
"chars": 1705,
"preview": "# Modified from OpenAI's diffusion repos\r\n# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/g"
},
{
"path": "diffusion/diffusion_utils.py",
"chars": 3277,
"preview": "# Modified from OpenAI's diffusion repos\r\n# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/g"
},
{
"path": "diffusion/gaussian_diffusion.py",
"chars": 37194,
"preview": "# Modified from OpenAI's diffusion repos\r\n# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/g"
},
{
"path": "diffusion/respace.py",
"chars": 5648,
"preview": "# Modified from OpenAI's diffusion repos\r\n# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/g"
},
{
"path": "diffusion/timestep_sampler.py",
"chars": 6163,
"preview": "# Modified from OpenAI's diffusion repos\r\n# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/g"
},
{
"path": "models/__init__.py",
"chars": 1118,
"preview": "import os\r\nimport sys\r\nsys.path.append(os.path.split(sys.path[0])[0])\r\n\r\nfrom .unet import UNet3DConditionModel\r\nfrom to"
},
{
"path": "models/attention.py",
"chars": 44586,
"preview": "# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py\nimport os\nimport sys"
},
{
"path": "models/clip.py",
"chars": 4768,
"preview": "import numpy\r\nimport torch.nn as nn\r\nfrom transformers import CLIPTokenizer, CLIPTextModel\r\n\r\nimport transformers\r\ntrans"
},
{
"path": "models/resnet.py",
"chars": 7295,
"preview": "# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py\nimport os\nimport sys\nsy"
},
{
"path": "models/unet.py",
"chars": 28787,
"preview": "# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py\n\nfrom datacl"
},
{
"path": "models/unet_blocks.py",
"chars": 24569,
"preview": "# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py\nimport os\nimpor"
},
{
"path": "models/utils.py",
"chars": 7329,
"preview": "# adopted from\r\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\r\n# and"
},
{
"path": "requirements.txt",
"chars": 405,
"preview": "bark_ssg==1.3.4\ndecord==0.6.0\ndiffusers==0.25.0\neinops==0.7.0\nimageio==2.28.0\nipython==8.14.0\nlibrosa==0.10.1\nmmcv==2.1."
},
{
"path": "results/vlog/teddy_travel/script/audio_prompts.txt",
"chars": 3145,
"preview": "[\n {\n \"video fragment id\": 1,\n \"video fragment description\": \"Teddy is planning on paper.\"\n },\n {\n \"video fragment id"
},
{
"path": "results/vlog/teddy_travel/script/protagonists_places.txt",
"chars": 582,
"preview": "[\n {\n \"id\": 1,\n \"name\": \"Teddy\",\n \"description\": \"A teddy bear with a dream of traveling the wor"
},
{
"path": "results/vlog/teddy_travel/script/time_scripts.txt",
"chars": 1097,
"preview": "[\n{\n \"video fragment id\": 1,\n \"time\": 2\n},\n{\n \"video fragment id\": 2,\n \"time\": 3\n},\n{\n \"video fragment id"
},
{
"path": "results/vlog/teddy_travel/script/video_prompts.txt",
"chars": 2235,
"preview": "[\n{\n \"video fragment id\": 1,\n \"video fragment description\": \"Teddy bear in a kid's room.\"\n},\n{\n \"video fragment"
},
{
"path": "results/vlog/teddy_travel/script/zh_video_prompts.txt",
"chars": 2078,
"preview": "[\n{\n \"序号\": 1,\n \"描述\": \"泰迪熊在孩子的房间里。\",\n },\n {\n \"序号\": 2,\n"
},
{
"path": "results/vlog/teddy_travel/story.txt",
"chars": 783,
"preview": "Once upon a time, there was a teddy bear named Teddy who dreamed of traveling the world. One day, his dream came true to"
},
{
"path": "results/vlog/teddy_travel_/story.txt",
"chars": 783,
"preview": "Once upon a time, there was a teddy bear named Teddy who dreamed of traveling the world. One day, his dream came true to"
},
{
"path": "sample_scripts/vlog_read_script_sample.py",
"chars": 14384,
"preview": "import torch\n\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport os\nimport sys\nt"
},
{
"path": "sample_scripts/vlog_write_script.py",
"chars": 4317,
"preview": "import torch\nimport os\nos.environ['CURL_CA_BUNDLE'] = ''\nimport argparse\nfrom omegaconf import OmegaConf\nfrom diffusers "
},
{
"path": "sample_scripts/with_mask_ref_sample.py",
"chars": 11798,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "sample_scripts/with_mask_sample.py",
"chars": 10873,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\r\n# All rights reserved.\r\n\r\n# This source code is licensed under the"
},
{
"path": "utils.py",
"chars": 14641,
"preview": "import os\nimport math\nimport torch\nimport logging\nimport subprocess\nimport numpy as np\nimport torch.distributed as dist\n"
},
{
"path": "vlogger/STEB/model_transform.py",
"chars": 4345,
"preview": "import torch\n# import argparse\n# from omegaconf import OmegaConf\n# from models import get_models\n# import sys\n# import o"
},
{
"path": "vlogger/planning_utils/gpt4_utils.py",
"chars": 24777,
"preview": "import openai\nimport re\nimport ast\n\n# Enter your openai key here\n# Allow multiple keys to be filled in to prevent the nu"
},
{
"path": "vlogger/videoaudio.py",
"chars": 4261,
"preview": "import os\nimport ast\nfrom IPython.display import Audio\nimport nltk # we'll use this to split into sentences\nimport nump"
},
{
"path": "vlogger/videocaption.py",
"chars": 1969,
"preview": "import torch\nimport ast\nimport os\nimport cv2 as cv\nfrom PIL import Image, ImageDraw, ImageFont\nfrom decord import VideoR"
},
{
"path": "vlogger/videofusion.py",
"chars": 659,
"preview": "import torch\nimport os\nfrom decord import VideoReader, cpu\nimport numpy as np\nimport torchvision\n\n\ndef fusion(path):\n "
}
]
// ... and 2 more files (download for full content)
About this extraction
This page contains the full source code of the zhuangshaobin/Vlogger GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 39 files (20.8 MB), approximately 73.7k tokens, and a symbol index with 269 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.