Repository: AIGODLIKE/ComfyUI-ToonCrafter
Branch: main
Commit: 96024189ecb2
Files: 98
Total size: 933.1 KB
Directory structure:
gitextract_h58eh22a/
├── .github/
│   └── workflows/
│       └── publish.yml
├── .gitignore
├── LICENSE
├── ToonCrafter/
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── cldm/
│   │   ├── cldm.py
│   │   ├── ddim_hacked.py
│   │   ├── hack.py
│   │   ├── logger.py
│   │   └── model.py
│   ├── configs/
│   │   ├── cldm_v21.yaml
│   │   ├── inference_512_v1.0.yaml
│   │   ├── training_1024_v1.0/
│   │   │   ├── config.yaml
│   │   │   └── run.sh
│   │   └── training_512_v1.0/
│   │       ├── config.yaml
│   │       └── run.sh
│   ├── gradio_app.py
│   ├── ldm/
│   │   ├── data/
│   │   │   ├── __init__.py
│   │   │   └── util.py
│   │   ├── models/
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion/
│   │   │       ├── __init__.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       ├── dpm_solver/
│   │   │       │   ├── __init__.py
│   │   │       │   ├── dpm_solver.py
│   │   │       │   └── sampler.py
│   │   │       ├── plms.py
│   │   │       └── sampling_util.py
│   │   ├── modules/
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   ├── upscaling.py
│   │   │   │   └── util.py
│   │   │   ├── distributions/
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders/
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   └── utils_image.py
│   │   │   └── midas/
│   │   │       ├── __init__.py
│   │   │       ├── api.py
│   │   │       ├── midas/
│   │   │       │   ├── __init__.py
│   │   │       │   ├── base_model.py
│   │   │       │   ├── blocks.py
│   │   │       │   ├── dpt_depth.py
│   │   │       │   ├── midas_net.py
│   │   │       │   ├── midas_net_custom.py
│   │   │       │   ├── transforms.py
│   │   │       │   └── vit.py
│   │   │       └── utils.py
│   │   └── util.py
│   ├── lvdm/
│   │   ├── __init__.py
│   │   ├── basics.py
│   │   ├── common.py
│   │   ├── data/
│   │   │   ├── base.py
│   │   │   └── webvid.py
│   │   ├── distributions.py
│   │   ├── ema.py
│   │   ├── models/
│   │   │   ├── autoencoder.py
│   │   │   ├── autoencoder_dualref.py
│   │   │   ├── ddpm3d.py
│   │   │   ├── samplers/
│   │   │   │   ├── ddim.py
│   │   │   │   └── ddim_multiplecond.py
│   │   │   └── utils_diffusion.py
│   │   └── modules/
│   │       ├── attention.py
│   │       ├── attention_svd.py
│   │       ├── encoders/
│   │       │   ├── condition.py
│   │       │   └── resampler.py
│   │       ├── networks/
│   │       │   ├── ae_modules.py
│   │       │   └── openaimodel3d.py
│   │       └── x_transformer.py
│   ├── main/
│   │   ├── __init__.py
│   │   ├── callbacks.py
│   │   ├── trainer.py
│   │   ├── utils_data.py
│   │   └── utils_train.py
│   ├── prompts/
│   │   └── 512_interp/
│   │       └── prompts.txt
│   ├── requirements.txt
│   ├── scripts/
│   │   ├── evaluation/
│   │   │   ├── ddp_wrapper.py
│   │   │   ├── funcs.py
│   │   │   └── inference.py
│   │   ├── gradio/
│   │   │   ├── i2v_test.py
│   │   │   └── i2v_test_application.py
│   │   └── run.sh
│   └── utils/
│       ├── __init__.py
│       ├── save_video.py
│       └── utils.py
├── __init__.py
├── pre_run.py
├── pyproject.toml
├── readme.md
└── requirements.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
      - master
    paths:
      - "pyproject.toml"

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@main
        with:
          ## Add your own personal access token to your GitHub repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
================================================
FILE: .gitignore
================================================
.DS_Store
*pyc
.vscode
__pycache__
*.egg-info
checkpoints
ToonCrafter/checkpoints
results
backup
LOG
/models
ToonCrafter/tmp
Thumbs.db
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright Tencent
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: ToonCrafter/.gitignore
================================================
.DS_Store
*pyc
.vscode
__pycache__
*.egg-info
checkpoints
results
backup
LOG
================================================
FILE: ToonCrafter/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright Tencent
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: ToonCrafter/README.md
================================================
## ___***ToonCrafter: Generative Cartoon Interpolation***___
<!-- {: width="50%"} -->
<!--  -->
<div align="center">
</div>
## 🔆 Introduction
⚠️ Please check our [disclaimer](#disc) first.
🤗 ToonCrafter can interpolate between two cartoon images by leveraging pre-trained image-to-video diffusion priors. Please check our project page and paper for more information. <br>
### 1.1 Showcases (512x320)
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td>Input starting frame</td>
<td>Input ending frame</td>
<td>Generated video</td>
</tr>
<tr>
<td>
<img src=assets/72109_125.mp4_00-00.png width="250">
</td>
<td>
<img src=assets/72109_125.mp4_00-01.png width="250">
</td>
<td>
<img src=assets/00.gif width="250">
</td>
</tr>
<tr>
<td>
<img src=assets/Japan_v2_2_062266_s2_frame1.png width="250">
</td>
<td>
<img src=assets/Japan_v2_2_062266_s2_frame3.png width="250">
</td>
<td>
<img src=assets/03.gif width="250">
</td>
</tr>
<tr>
<td>
<img src=assets/Japan_v2_1_070321_s3_frame1.png width="250">
</td>
<td>
<img src=assets/Japan_v2_1_070321_s3_frame3.png width="250">
</td>
<td>
<img src=assets/02.gif width="250">
</td>
</tr>
<tr>
<td>
<img src=assets/74302_1349_frame1.png width="250">
</td>
<td>
<img src=assets/74302_1349_frame3.png width="250">
</td>
<td>
<img src=assets/01.gif width="250">
</td>
</tr>
</table>
### 1.2 Sparse sketch guidance
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td>Input starting frame</td>
<td>Input ending frame</td>
<td>Input sketch guidance</td>
<td>Generated video</td>
</tr>
<tr>
<td>
<img src=assets/72105_388.mp4_00-00.png width="200">
</td>
<td>
<img src=assets/72105_388.mp4_00-01.png width="200">
</td>
<td>
<img src=assets/06.gif width="200">
</td>
<td>
<img src=assets/07.gif width="200">
</td>
</tr>
<tr>
<td>
<img src=assets/72110_255.mp4_00-00.png width="200">
</td>
<td>
<img src=assets/72110_255.mp4_00-01.png width="200">
</td>
<td>
<img src=assets/12.gif width="200">
</td>
<td>
<img src=assets/13.gif width="200">
</td>
</tr>
</table>
### 2. Applications
#### 2.1 Cartoon Sketch Interpolation (see project page for more details)
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td>Input starting frame</td>
<td>Input ending frame</td>
<td>Generated video</td>
</tr>
<tr>
<td>
<img src=assets/frame0001_10.png width="250">
</td>
<td>
<img src=assets/frame0016_10.png width="250">
</td>
<td>
<img src=assets/10.gif width="250">
</td>
</tr>
<tr>
<td>
<img src=assets/frame0001_11.png width="250">
</td>
<td>
<img src=assets/frame0016_11.png width="250">
</td>
<td>
<img src=assets/11.gif width="250">
</td>
</tr>
</table>
#### 2.2 Reference-based Sketch Colorization
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td>Input sketch</td>
<td>Input reference</td>
<td>Colorization results</td>
</tr>
<tr>
<td>
<img src=assets/04.gif width="250">
</td>
<td>
<img src=assets/frame0001_05.png width="250">
</td>
<td>
<img src=assets/05.gif width="250">
</td>
</tr>
<tr>
<td>
<img src=assets/08.gif width="250">
</td>
<td>
<img src=assets/frame0001_09.png width="250">
</td>
<td>
<img src=assets/09.gif width="250">
</td>
</tr>
</table>
## 📝 Changelog
- [ ] Add sketch control and colorization function.
- __[2024.05.29]__: 🔥🔥 Release code and model weights.
- __[2024.05.28]__: Launch the project page and update the arXiv preprint.
<br>
## 🧰 Models
|Model|Resolution|GPU Mem. & Inference Time (A100, ddim 50steps)|Checkpoint|
|:---------|:---------|:--------|:--------|
|ToonCrafter_512|320x512| TBD (`perframe_ae=True`)|[Hugging Face](https://huggingface.co/Doubiiu/ToonCrafter/blob/main/model.ckpt)|
ToonCrafter currently supports generating videos of up to 16 frames at a resolution of 512x320. Inference time can be reduced by using fewer DDIM steps.
## ⚙️ Setup
### Install Environment via Anaconda (Recommended)
```bash
conda create -n tooncrafter python=3.8.5
conda activate tooncrafter
pip install -r requirements.txt
```
## 💫 Inference
### 1. Command line
Download the pretrained ToonCrafter_512 checkpoint and place `model.ckpt` at `checkpoints/tooncrafter_512_interp_v1/model.ckpt`.
```bash
sh scripts/run.sh
```
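Since the run script silently depends on the checkpoint path above, a quick pre-flight check can save a failed run. A minimal sketch (`checkpoint_ready` is a hypothetical helper, not part of this repo):

```python
from pathlib import Path

def checkpoint_ready(root: str = ".") -> bool:
    """Return True only if model.ckpt sits where scripts/run.sh
    expects it, relative to the given repository root."""
    ckpt = Path(root) / "checkpoints" / "tooncrafter_512_interp_v1" / "model.ckpt"
    return ckpt.is_file()
```

Call `checkpoint_ready()` from the repo root before invoking `sh scripts/run.sh`.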
### 2. Local Gradio demo
Download the pretrained model and place it in the directory described above.
```bash
python gradio_app.py
```
<!-- ## 🤝 Community Support -->
<a name="disc"></a>
## 📢 Disclaimer
Calm down. Our framework opens up the era of generative cartoon interpolation, but due to the variability of the generative video prior, success is not guaranteed on every input.
⚠️ This is an open-source research exploration, not a commercial product, and it may not meet all your expectations.
This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
****
================================================
FILE: ToonCrafter/__init__.py
================================================
import sys
from pathlib import Path
sys.path.append(Path(__file__).parent.as_posix())
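This `sys.path` append is what lets the vendored code's bare imports (e.g. `from lvdm...`) resolve when ComfyUI loads the node from an arbitrary directory. A sketch of the same pattern with a duplicate guard added (`register_package_root` is illustrative, not part of the repo):

```python
import sys
from pathlib import Path

def register_package_root(file: str) -> str:
    """Add a module's parent directory to sys.path (at most once) so
    bare imports of sibling packages resolve regardless of the CWD."""
    root = Path(file).parent.as_posix()
    if root not in sys.path:
        sys.path.append(root)
    return root
```

The guard avoids growing `sys.path` on repeated imports; the original three-line `__init__.py` skips it because the module body runs only once per interpreter.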
================================================
FILE: ToonCrafter/cldm/cldm.py
================================================
import einops
import torch
import torch as th
import torch.nn as nn
from ToonCrafter.ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ToonCrafter.ldm.modules.attention import SpatialTransformer
from ToonCrafter.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from lvdm.modules.networks.openaimodel3d import UNetModel
from ToonCrafter.ldm.models.diffusion.ddpm import LatentDiffusion
from ToonCrafter.ldm.util import log_txt_as_img, exists, instantiate_from_config
from ToonCrafter.ldm.models.diffusion.ddim import DDIMSampler
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps, context=None, features_adapter=None, fs=None, control = None, **kwargs):
b,_,t,_,_ = x.shape
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb)
## repeat t times for context [(b t) 77 768] & time embedding
## check if we use per-frame image conditioning
_, l_context, _ = context.shape
if l_context == 77 + t*16: ## !!! HARD CODE here
context_text, context_img = context[:,:77,:], context[:,77:,:]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
context = torch.cat([context_text, context_img], dim=1)
else:
context = context.repeat_interleave(repeats=t, dim=0)
emb = emb.repeat_interleave(repeats=t, dim=0)
## always in shape (b t) c h w, except for temporal layer
x = rearrange(x, 'b c t h w -> (b t) c h w')
## combine emb
if self.fs_condition:
if fs is None:
fs = torch.tensor(
[self.default_fs] * b, dtype=torch.long, device=x.device)
fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed
h = x.type(self.dtype)
adapter_idx = 0
hs = []
with torch.no_grad():
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if id ==0 and self.addition_attention:
h = self.init_attn(h, emb, context=context, batch_size=b)
## plug-in adapter features
if ((id+1)%3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
adapter_idx += 1
hs.append(h)
if features_adapter is not None:
assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'
h = self.middle_block(h, emb, context=context, batch_size=b)
if control is not None:
h += control.pop()
for module in self.output_blocks:
if control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
h = h.type(x.dtype)
y = self.out(h)
# reshape back to (b c t h w)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
return y
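The per-frame conditioning branch above (the `77 + t*16` case) can be sketched with plain numpy reshapes in place of einops; the shapes here (b=2 videos, t=4 frames, 16 image tokens per frame, c=8 channels) are illustrative assumptions, not values from the repo:

```python
import numpy as np

b, t, c = 2, 4, 8
context = np.random.rand(b, 77 + t * 16, c)

# split text tokens from per-frame image tokens
context_text, context_img = context[:, :77, :], context[:, 77:, :]

# text: repeat each video's 77 tokens once per frame -> (b*t, 77, c),
# mirroring repeat_interleave(repeats=t, dim=0)
context_text = np.repeat(context_text, t, axis=0)

# image: 'b (t l) c -> (b t) l c' -> (b*t, 16, c)
context_img = context_img.reshape(b * t, 16, c)

# per-frame context fed to cross-attention: (b*t, 93, c)
per_frame = np.concatenate([context_text, context_img], axis=1)
print(per_frame.shape)  # (8, 93, 8)
```

Every frame thus attends to the shared text tokens plus its own 16 image tokens, which is why the UNet runs in `(b t) c h w` layout afterwards.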
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self.input_hint_block = TimestepEmbedSequential(
conv_nd(dims, hint_channels, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context, True))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
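A minimal sketch (not repo code) of why the `zero_module`-initialized convolutions matter: a conv whose weights and bias start at zero maps any input to zeros, so at the start of training the control branch contributes nothing and the pretrained UNet behaviour is preserved exactly:

```python
import numpy as np

def zero_conv1x1(x, channels):
    # 1x1 conv with zero-initialized weight (out_ch, in_ch) and bias,
    # applied per pixel via einsum
    w = np.zeros((channels, channels))
    b = np.zeros(channels)
    return np.einsum('oc,chw->ohw', w, x) + b[:, None, None]

h = np.random.rand(4, 8, 8)         # some UNet feature map (c, h, w)
control = zero_conv1x1(h, 4)
assert np.allclose(h + control, h)  # identity at initialization
```

Gradients still flow through the zero conv, so the control signal ramps up smoothly from zero during fine-tuning.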
class ControlLDM(LatentDiffusion):
def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
super().__init__(*args, **kwargs)
self.control_model = instantiate_from_config(control_stage_config)
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
control = batch[self.control_key]
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, 'b h w c -> b c h w')
control = control.to(memory_format=torch.contiguous_format).float()
return x, dict(c_crossattn=[c], c_concat=[control])
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
cond_txt = torch.cat(cond['c_crossattn'], 1)
if cond['c_concat'] is None:
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
else:
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
control = [c * scale for c, scale in zip(control, self.control_scales)]
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
return eps
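The control scaling in `apply_model` can be illustrated in isolation; the 13 entries correspond to the 12 encoder zero-conv outputs plus the middle-block output (values below are stand-ins, not real feature maps):

```python
control = [float(i) for i in range(13)]  # stand-in control features
control_scales = [1.0] * 13              # default: full strength per level
scaled = [c * s for c, s in zip(control, control_scales)]
assert scaled == control                 # unit scales are a no-op
```

Lowering individual scales (e.g. only the deepest levels) is a common way to weaken control influence per resolution.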
@torch.no_grad()
def get_unconditional_conditioning(self, N):
return self.get_learned_conditioning([""] * N)
@torch.no_grad()
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
use_ddim = ddim_steps is not None
log = dict()
z, c = self.get_input(batch, self.first_stage_key, bs=N)
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
N = min(z.shape[0], N)
n_row = min(z.shape[0], n_row)
log["reconstruction"] = self.decode_first_stage(z)
log["control"] = c_cat * 2.0 - 1.0
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N)
uc_cat = c_cat # torch.zeros_like(c_cat)
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
ddim_sampler = DDIMSampler(self)
b, c, h, w = cond["c_concat"][0].shape
shape = (self.channels, h // 8, w // 8)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
return samples, intermediates
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control_model.parameters())
if not self.sd_locked:
params += list(self.model.diffusion_model.output_blocks.parameters())
params += list(self.model.diffusion_model.out.parameters())
opt = torch.optim.AdamW(params, lr=lr)
return opt
def low_vram_shift(self, is_diffusing):
if is_diffusing:
self.model = self.model.cuda()
self.control_model = self.control_model.cuda()
self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model = self.model.cpu()
self.control_model = self.control_model.cpu()
self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda()
================================================
FILE: ToonCrafter/cldm/ddim_hacked.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ToonCrafter.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
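A numeric sketch of the sigma formula registered above, using an assumed toy alpha-bar schedule: sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev). With eta = 0 every sigma vanishes and DDIM sampling is fully deterministic; eta = 1 recovers DDPM-like stochasticity:

```python
import numpy as np

alphas_cumprod = np.linspace(0.99, 0.01, 10)   # toy decreasing schedule
a_t, a_prev = alphas_cumprod[1:], alphas_cumprod[:-1]

def ddim_sigmas(eta):
    return eta * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))

assert np.all(ddim_sigmas(0.0) == 0.0)   # eta=0: deterministic DDIM
assert np.all(ddim_sigmas(1.0) > 0.0)    # eta=1: noise injected each step
```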
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
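The x0-prediction and update steps in `p_sample_ddim` can be checked on toy arrays (not repo code): if x_t = sqrt(a_t)*x0 + sqrt(1-a_t)*eps, then with a perfect eps prediction, pred_x0 = (x_t - sqrt(1-a_t)*eps) / sqrt(a_t) recovers x0 exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 4))
eps = rng.standard_normal((4, 4))
a_t = 0.7

# forward noising at level a_t
x_t = np.sqrt(a_t) * x0 + np.sqrt(1 - a_t) * eps

# x0 prediction, as in the parameterization != "v" branch above
pred_x0 = (x_t - np.sqrt(1 - a_t) * eps) / np.sqrt(a_t)
assert np.allclose(pred_x0, x0)

# deterministic (eta=0) DDIM step: move to the previous noise level,
# dir_xt = sqrt(1 - a_prev) * e_t since sigma_t = 0
a_prev = 0.8
x_prev = np.sqrt(a_prev) * pred_x0 + np.sqrt(1 - a_prev) * eps
```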
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
num_reference_steps = timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
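The closed-form forward noising used by `stochastic_encode` is x_t = sqrt(alpha_bar_t)*x0 + sqrt(1-alpha_bar_t)*noise; a quick sanity check (illustrative, not repo code) that the two coefficients always square-sum to 1, so unit-variance inputs stay unit-variance at every t:

```python
import numpy as np

for a_bar in (0.99, 0.5, 0.01):
    c0, c1 = np.sqrt(a_bar), np.sqrt(1 - a_bar)
    assert np.isclose(c0 ** 2 + c1 ** 2, 1.0)
```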
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
================================================
FILE: ToonCrafter/cldm/hack.py
================================================
import torch
import einops
from ToonCrafter import ldm
import ToonCrafter.ldm.modules.encoders.modules
import ToonCrafter.ldm.modules.attention
from transformers import logging
from ToonCrafter.ldm.modules.attention import default
def disable_verbosity():
logging.set_verbosity_error()
print('logging improved.')
return
def enable_sliced_attention():
ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
print('Enabled sliced_attention.')
return
def hack_everything(clip_skip=0):
disable_verbosity()
ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
print('Enabled clip hacks.')
return
# Written by Lvmin
def _hacked_clip_forward(self, text):
PAD = self.tokenizer.pad_token_id
EOS = self.tokenizer.eos_token_id
BOS = self.tokenizer.bos_token_id
def tokenize(t):
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
def transformer_encode(t):
if self.clip_skip > 1:
rt = self.transformer(input_ids=t, output_hidden_states=True)
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
else:
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
def split(x):
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
def pad(x, p, i):
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
raw_tokens_list = tokenize(text)
tokens_list = []
for raw_tokens in raw_tokens_list:
raw_tokens_123 = split(raw_tokens)
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
tokens_list.append(raw_tokens_123)
tokens_list = torch.IntTensor(tokens_list).to(self.device)
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
y = transformer_encode(feed)
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
return z
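A pure-Python sketch of how `_hacked_clip_forward` packs long prompts (the token ids below are arbitrary placeholders, not the real CLIP vocabulary ids): the raw stream is cut into three 75-token chunks, each wrapped with BOS/EOS and padded to 77, yielding three standard CLIP windows (3 x 77 tokens total):

```python
BOS, EOS, PAD = 101, 102, 0   # placeholder ids for illustration

def split(x):
    return x[0:75], x[75:150], x[150:225]

def pad(x, p, i):
    return x[:i] if len(x) >= i else x + [p] * (i - len(x))

raw = list(range(1000, 1100))              # 100 fake token ids
chunks = [pad([BOS] + c + [EOS], PAD, 77) for c in split(raw)]

assert [len(c) for c in chunks] == [77, 77, 77]
assert chunks[0][1:76] == raw[:75]         # first window: tokens 0..74
```

The three windows are then encoded in one batch and re-joined along the sequence axis, which is how the hack sidesteps CLIP's 77-token limit.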
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = 1
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range(0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i + att_step, :, :] = sim_buffer
del sim_buffer
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)
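The idea behind the sliced-attention hack above, sketched in numpy (illustrative, not repo code): computing softmax(Q K^T / sqrt(d)) V in independent slices along the batch-of-heads axis gives the same result as the full computation, only with a smaller peak memory footprint:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    return softmax(np.einsum('bid,bjd->bij', q, k) * scale) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 5, 8)) for _ in range(3))

full = attention(q, k, v)
# process 2 (batch*head) slices at a time, as att_step does above
sliced = np.concatenate([attention(q[i:i + 2], k[i:i + 2], v[i:i + 2])
                         for i in range(0, 6, 2)])
assert np.allclose(full, sliced)
```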
================================================
FILE: ToonCrafter/cldm/logger.py
================================================
import os
import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
class ImageLogger(Callback):
def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
log_images_kwargs=None):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "image_log", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and
callable(pl_module.log_images) and
self.max_images > 0):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir, split, images,
pl_module.global_step, pl_module.current_epoch, batch_idx)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
return check_idx % self.batch_freq == 0
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if not self.disabled:
self.log_img(pl_module, batch, batch_idx, split="train")
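For reference, `log_local` above rescales images from `[-1, 1]` to `[0, 1]` and then to `uint8` before saving, using a fixed filename pattern. A minimal, torch-free sketch of both (the helper name `to_uint8` is ours, not part of this repo):

```python
def to_uint8(value, rescale=True):
    # Mirror of log_local's scaling: optionally map [-1, 1] -> [0, 1], then to 0..255
    if rescale:
        value = (value + 1.0) / 2.0
    return int(value * 255)

# The filename pattern log_local uses for each logged key
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format("samples", 1500, 3, 42)
```

The zero-padded fields keep files sortable by global step, epoch, and batch index.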
================================================
FILE: ToonCrafter/cldm/model.py
================================================
import os
import torch
from omegaconf import OmegaConf
from comfy.ldm.util import instantiate_from_config
def get_state_dict(d):
return d.get('state_dict', d)
def load_state_dict(ckpt_path, location='cpu'):
_, extension = os.path.splitext(ckpt_path)
if extension.lower() == ".safetensors":
import safetensors.torch
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else:
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
state_dict = get_state_dict(state_dict)
print(f'Loaded state_dict from [{ckpt_path}]')
return state_dict
def create_model(config_path):
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model).cpu()
print(f'Loaded model config from [{config_path}]')
return model
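`get_state_dict` above handles both Lightning-style checkpoints, which nest weights under a `state_dict` key, and raw weight dicts. A quick illustration with plain dicts:

```python
def get_state_dict(d):
    # Same one-liner as cldm/model.py: unwrap nested checkpoints, pass raw dicts through
    return d.get('state_dict', d)

nested = {'state_dict': {'w': 1}, 'epoch': 7}  # Lightning-style checkpoint
flat = {'w': 1}                                # raw weight dict
```

Note that `load_state_dict` applies it twice, which is safe because the function is idempotent on an already-unwrapped dict.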
================================================
FILE: ToonCrafter/configs/cldm_v21.yaml
================================================
control_stage_config:
target: ToonCrafter.cldm.cldm.ControlNet
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
hint_channels: 1
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
================================================
FILE: ToonCrafter/configs/inference_512_v1.0.yaml
================================================
model:
target: lvdm.models.ddpm3d.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
timesteps: 1000
first_stage_key: video
cond_stage_key: caption
cond_stage_trainable: False
conditioning_key: hybrid
image_size: [40, 64]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_type: 'empty_seq'
use_dynamic_rescale: true
base_scale: 0.7
fps_condition_type: 'fps'
perframe_ae: True
loop_video: true
unet_config:
target: lvdm.modules.networks.openaimodel3d.UNetModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: true
use_relative_position: false
use_causal_attention: False
temporal_length: 16
addition_attention: true
image_cross_attention: true
default_fs: 24
fs_condition: true
first_stage_config:
target: lvdm.models.autoencoder.AutoencoderKL_Dualref
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: true
layer: "penultimate"
img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: lvdm.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: 16
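A note on `image_size: [40, 64]` in this config: it is the latent resolution, i.e. the 320x512 pixel working resolution divided by the autoencoder's total downsampling factor, `2 ** (len(ch_mult) - 1) = 8`. A sketch of that arithmetic (the variable names are illustrative, not config keys):

```python
ch_mult = [1, 2, 4, 4]                 # from ddconfig above
downsample = 2 ** (len(ch_mult) - 1)   # each level after the first halves H and W -> 8
pixel_hw = (320, 512)                  # the 512-model's working resolution
latent_hw = tuple(s // downsample for s in pixel_hw)
```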
================================================
FILE: ToonCrafter/configs/training_1024_v1.0/config.yaml
================================================
model:
pretrained_checkpoint: checkpoints/dynamicrafter_1024_v1/model.ckpt
base_learning_rate: 1.0e-05
scale_lr: False
target: lvdm.models.ddpm3d.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: video
cond_stage_key: caption
cond_stage_trainable: False
image_proj_model_trainable: True
conditioning_key: hybrid
image_size: [72, 128]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_prob: 0.05
uncond_type: 'empty_seq'
rand_cond_frame: true
use_dynamic_rescale: true
base_scale: 0.3
fps_condition_type: 'fps'
perframe_ae: True
unet_config:
target: lvdm.modules.networks.openaimodel3d.UNetModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: true
use_relative_position: false
use_causal_attention: False
temporal_length: 16
addition_attention: true
image_cross_attention: true
default_fs: 10
fs_condition: true
first_stage_config:
target: lvdm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: true
layer: "penultimate"
img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: lvdm.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: 16
data:
target: utils_data.DataModuleFromConfig
params:
batch_size: 1
num_workers: 12
wrap: false
train:
target: lvdm.data.webvid.WebVid
params:
data_dir: <WebVid10M DATA>
meta_path: <.csv FILE>
video_length: 16
frame_stride: 6
load_raw_resolution: true
resolution: [576, 1024]
spatial_transform: resize_center_crop
random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above)
lightning:
precision: 16
# strategy: deepspeed_stage_2
trainer:
benchmark: True
accumulate_grad_batches: 2
max_steps: 100000
# logger
log_every_n_steps: 50
# val
val_check_interval: 0.5
gradient_clip_algorithm: 'norm'
gradient_clip_val: 0.5
callbacks:
model_checkpoint:
target: pytorch_lightning.callbacks.ModelCheckpoint
params:
every_n_train_steps: 9000 #1000
filename: "{epoch}-{step}"
save_weights_only: True
metrics_over_trainsteps_checkpoint:
target: pytorch_lightning.callbacks.ModelCheckpoint
params:
filename: '{epoch}-{step}'
save_weights_only: True
every_n_train_steps: 10000 # previously 20000 (~3 s/step x 20k steps)
batch_logger:
target: callbacks.ImageLogger
params:
batch_frequency: 500
to_local: False
max_images: 8
log_images_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 7.5
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
================================================
FILE: ToonCrafter/configs/training_1024_v1.0/run.sh
================================================
# NCCL configuration
# export NCCL_DEBUG=INFO
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export NCCL_TOPO_FILE=/tmp/topo.txt
# args
name="training_1024_v1.0"
config_file=configs/${name}/config.yaml
# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="<YOUR_SAVE_ROOT_DIR>"
mkdir -p $save_root/$name
## run (HOST_GPU_NUM must be set in the environment; note that torch.distributed.launch is deprecated in recent PyTorch in favor of torchrun)
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
./main/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices $HOST_GPU_NUM \
lightning.trainer.num_nodes=1
## debugging
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
# --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
# ./main/trainer.py \
# --base $config_file \
# --train \
# --name $name \
# --logdir $save_root \
# --devices 4 \
# lightning.trainer.num_nodes=1
================================================
FILE: ToonCrafter/configs/training_512_v1.0/config.yaml
================================================
model:
pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt
base_learning_rate: 1.0e-05
scale_lr: False
target: lvdm.models.ddpm3d.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: video
cond_stage_key: caption
cond_stage_trainable: False
image_proj_model_trainable: True
conditioning_key: hybrid
image_size: [40, 64]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_prob: 0.05
uncond_type: 'empty_seq'
rand_cond_frame: true
use_dynamic_rescale: true
base_scale: 0.7
fps_condition_type: 'fps'
perframe_ae: True
unet_config:
target: lvdm.modules.networks.openaimodel3d.UNetModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: true
use_relative_position: false
use_causal_attention: False
temporal_length: 16
addition_attention: true
image_cross_attention: true
default_fs: 10
fs_condition: true
first_stage_config:
target: lvdm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: true
layer: "penultimate"
img_cond_stage_config:
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: lvdm.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: 16
data:
target: utils_data.DataModuleFromConfig
params:
batch_size: 2
num_workers: 12
wrap: false
train:
target: lvdm.data.webvid.WebVid
params:
data_dir: <WebVid10M DATA>
meta_path: <.csv FILE>
video_length: 16
frame_stride: 6
load_raw_resolution: true
resolution: [320, 512]
spatial_transform: resize_center_crop
random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above)
lightning:
precision: 16
# strategy: deepspeed_stage_2
trainer:
benchmark: True
accumulate_grad_batches: 2
max_steps: 100000
# logger
log_every_n_steps: 50
# val
val_check_interval: 0.5
gradient_clip_algorithm: 'norm'
gradient_clip_val: 0.5
callbacks:
model_checkpoint:
target: pytorch_lightning.callbacks.ModelCheckpoint
params:
every_n_train_steps: 9000 #1000
filename: "{epoch}-{step}"
save_weights_only: True
metrics_over_trainsteps_checkpoint:
target: pytorch_lightning.callbacks.ModelCheckpoint
params:
filename: '{epoch}-{step}'
save_weights_only: True
every_n_train_steps: 10000 # previously 20000 (~3 s/step x 20k steps)
batch_logger:
target: callbacks.ImageLogger
params:
batch_frequency: 500
to_local: False
max_images: 8
log_images_kwargs:
ddim_steps: 50
unconditional_guidance_scale: 7.5
timestep_spacing: uniform_trailing
guidance_rescale: 0.7
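One practical note on this trainer block: with `batch_size: 2` per GPU, `accumulate_grad_batches: 2`, and the 8 GPUs exposed in run.sh, the effective optimization batch is 2 x 2 x 8 = 32. A small sketch of that bookkeeping (the variable names are ours, not config keys):

```python
per_gpu_batch = 2    # data.params.batch_size
grad_accum = 2       # lightning.trainer.accumulate_grad_batches
num_gpus = 8         # CUDA_VISIBLE_DEVICES=0..7 in run.sh
effective_batch = per_gpu_batch * grad_accum * num_gpus
```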
================================================
FILE: ToonCrafter/configs/training_512_v1.0/run.sh
================================================
# NCCL configuration
# export NCCL_DEBUG=INFO
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export NCCL_TOPO_FILE=/tmp/topo.txt
# args
name="training_512_v1.0"
config_file=configs/${name}/config.yaml
# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="<YOUR_SAVE_ROOT_DIR>"
mkdir -p $save_root/$name
## run
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
./main/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices $HOST_GPU_NUM \
lightning.trainer.num_nodes=1
## debugging
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
# --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
# ./main/trainer.py \
# --base $config_file \
# --train \
# --name $name \
# --logdir $save_root \
# --devices 4 \
# lightning.trainer.num_nodes=1
================================================
FILE: ToonCrafter/gradio_app.py
================================================
import os
import argparse
import sys
import gradio as gr
from scripts.gradio.i2v_test_application import Image2Video
sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
i2v_examples_interp_512 = [
['prompts/512_interp/74906_1462_frame1.png', 'walking man', 50, 7.5, 1.0, 10, 123, 'prompts/512_interp/74906_1462_frame3.png'],
['prompts/512_interp/Japan_v2_2_062266_s2_frame1.png', 'an anime scene', 50, 7.5, 1.0, 10, 789, 'prompts/512_interp/Japan_v2_2_062266_s2_frame3.png'],
['prompts/512_interp/Japan_v2_3_119235_s2_frame1.png', 'an anime scene', 50, 7.5, 1.0, 10, 123, 'prompts/512_interp/Japan_v2_3_119235_s2_frame3.png'],
]
def dynamicrafter_demo(result_dir='./tmp/', res=512):
if res == 1024:
resolution = '576_1024'
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}"""
elif res == 512:
resolution = '320_512'
css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px} #input_img2 {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}"""
elif res == 256:
resolution = '256_256'
css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}"""
else:
raise NotImplementedError(f"Unsupported resolution: {res}")
image2video = Image2Video(result_dir, resolution=resolution)
with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
with gr.Tab(label='ToonCrafter_320x512'):
with gr.Column():
with gr.Row():
with gr.Column():
with gr.Row():
i2v_input_image = gr.Image(label="Input Image1", elem_id="input_img")
with gr.Row():
i2v_input_text = gr.Text(label='Prompts')
with gr.Row():
i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=50000, step=1, value=123)
i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
with gr.Row():
i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
i2v_motion = gr.Slider(minimum=5, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=10)
i2v_end_btn = gr.Button("Generate")
with gr.Column():
with gr.Row():
i2v_input_image2 = gr.Image(label="Input Image2", elem_id="input_img2")
with gr.Row():
i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True, show_share_button=True)
gr.Examples(examples=i2v_examples_interp_512,
inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_input_image2],
outputs=[i2v_output_video],
fn=image2video.get_image,
cache_examples=False,
)
i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_input_image2],
outputs=[i2v_output_video],
fn=image2video.get_image
)
return dynamicrafter_iface
def get_parser():
parser = argparse.ArgumentParser()
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
result_dir = os.path.join('./', 'results')
dynamicrafter_iface = dynamicrafter_demo(result_dir)
dynamicrafter_iface.queue(max_size=12)
dynamicrafter_iface.launch(max_threads=1)
# dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=80, max_threads=1)
================================================
FILE: ToonCrafter/ldm/data/__init__.py
================================================
================================================
FILE: ToonCrafter/ldm/data/util.py
================================================
import torch
from ToonCrafter.ldm.modules.midas.api import load_midas_transform
class AddMiDaS(object):
def __init__(self, model_type):
super().__init__()
self.transform = load_midas_transform(model_type)
def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
return x
def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
return x
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
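`pt2np` and `np2pt` above are inverse affine maps between the model's `[-1, 1]` range and the MiDaS transform's `[0, 1]` input range. A dependency-free sketch of the round trip on plain floats (scalar analogues of the methods above, not repo code):

```python
def pt2np_scalar(x):
    # [-1, 1] -> [0, 1], as in AddMiDaS.pt2np
    return (x + 1.0) * 0.5

def np2pt_scalar(x):
    # [0, 1] -> [-1, 1], as in AddMiDaS.np2pt
    return x * 2.0 - 1.0
```

Composing the two recovers the input, which is what makes the hand-off to MiDaS lossless.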
================================================
FILE: ToonCrafter/ldm/models/autoencoder.py
================================================
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from ToonCrafter.ldm.modules.diffusionmodules.model import Encoder, Decoder
from ToonCrafter.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ToonCrafter.ldm.util import instantiate_from_config
from ToonCrafter.ldm.modules.ema import LitEma
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
super().__init__()  # initialize nn.Module before setting attributes
self.vq_interface = vq_interface
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
================================================
FILE: ToonCrafter/ldm/models/diffusion/__init__.py
================================================
================================================
FILE: ToonCrafter/ldm/models/diffusion/ddim.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ToonCrafter.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
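The classifier-free guidance blend in `p_sample_ddim` above batches the unconditional and conditional inputs, runs one forward pass, then combines the two predictions as `model_uncond + scale * (model_t - model_uncond)`. A minimal sketch of just that blend, using dummy tensors instead of real model outputs (`cfg_combine` is a hypothetical helper name, not part of this codebase):

```python
import torch

def cfg_combine(model_uncond, model_cond, scale):
    # e_t = e_uncond + s * (e_cond - e_uncond); scale == 1 reduces to the conditional output
    return model_uncond + scale * (model_cond - model_uncond)

# Dummy predictions standing in for the two halves of the chunked model output
uncond = torch.zeros(2, 4, 8, 8)
cond = torch.ones(2, 4, 8, 8)
guided = cfg_combine(uncond, cond, scale=7.5)
```

With these dummies the guided output is uniformly 7.5, i.e. the conditional direction amplified by the guidance scale.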
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
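`stochastic_encode` applies the closed-form forward diffusion `x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`. A small sketch of that formula on toy tensors; the `alphas_cumprod` values here are illustrative, not a real model schedule:

```python
import torch

torch.manual_seed(0)
# Illustrative cumulative-alpha schedule (a real one has num_timesteps entries)
alphas_cumprod = torch.tensor([0.99, 0.9, 0.5, 0.1])
x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
t = 2
# Jump straight to noise level t in one step
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * noise
# The closed form is exactly invertible given x0: recover the injected noise
noise_rec = (x_t - alphas_cumprod[t].sqrt() * x0) / (1 - alphas_cumprod[t]).sqrt()
```

As the comment in the method says, this is fast but not an exact-reconstruction encoding: the noise is sampled, so decoding will not reproduce `x0` exactly.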
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
================================================
FILE: ToonCrafter/ldm/models/diffusion/ddpm.py
================================================
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import ListConfig
from ToonCrafter.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ToonCrafter.ldm.modules.ema import LitEma
from ToonCrafter.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ToonCrafter.ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ToonCrafter.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ToonCrafter.ldm.models.diffusion.ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
'adm': 'y'}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
make_it_fit=False,
ucg_training=None,
reset_ema=False,
reset_num_ema_updates=False,
):
super().__init__()
assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
self.make_it_fit = make_it_fit
if reset_ema: assert exists(ckpt_path)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
if reset_ema:
assert self.use_ema
print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
            self.logvar = nn.Parameter(logvar, requires_grad=True)
else:
self.register_buffer('logvar', logvar)
self.ucg_training = ucg_training or dict()
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
else:
raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).any()
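`register_schedule` consumes betas from `make_beta_schedule` and derives the cumulative quantities used everywhere else. A numpy sketch of the "linear" schedule as implemented in this codebase (a linspace in sqrt-space, then squared) together with the cumulative products; treat the sqrt-space construction as an assumption about `make_beta_schedule` rather than a guarantee:

```python
import numpy as np

T, linear_start, linear_end = 1000, 1e-4, 2e-2
# "linear" schedule: linspace between sqrt(start) and sqrt(end), then squared
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, T) ** 2
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
# Shifted by one step, with alpha_bar_0's predecessor defined as 1
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
```

`alphas_cumprod` decreases monotonically from near 1 toward 0, which is why quantities like `sqrt_recip_alphas_cumprod` blow up at large t.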
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
if self.make_it_fit:
n_params = len([name for name, _ in
itertools.chain(self.named_parameters(),
self.named_buffers())])
for name, param in tqdm(
itertools.chain(self.named_parameters(),
self.named_buffers()),
desc="Fitting old weights to new weights",
total=n_params
):
                if name not in sd:
continue
old_shape = sd[name].shape
new_shape = param.shape
assert len(old_shape) == len(new_shape)
if len(new_shape) > 2:
# we only modify first two axes
assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim
if not new_shape == old_shape:
new_param = param.clone()
old_param = sd[name]
if len(new_shape) == 1:
for i in range(new_param.shape[0]):
new_param[i] = old_param[i % old_shape[0]]
elif len(new_shape) >= 2:
for i in range(new_param.shape[0]):
for j in range(new_param.shape[1]):
new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
n_used_old = torch.ones(old_shape[1])
for j in range(new_param.shape[1]):
n_used_old[j % old_shape[1]] += 1
n_used_new = torch.zeros(new_shape[1])
for j in range(new_param.shape[1]):
n_used_new[j] = n_used_old[j % old_shape[1]]
n_used_new = n_used_new[None, :]
while len(n_used_new.shape) < len(new_shape):
n_used_new = n_used_new.unsqueeze(-1)
new_param /= n_used_new
sd[name] = new_param
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys:\n {missing}")
if len(unexpected) > 0:
print(f"\nUnexpected Keys:\n {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_start_from_z_and_v(self, x_t, t, v):
# self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
# self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_eps_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
)
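The v-parameterization helpers above are linear identities: with `v = sqrt(a) * eps - sqrt(1 - a) * x0` and `x_t = sqrt(a) * x0 + sqrt(1 - a) * eps`, both `x0` and `eps` can be recovered exactly. A scalar-alpha sanity check (the value 0.7 is chosen only for illustration):

```python
import torch

a = torch.tensor(0.7)            # alpha_bar at some timestep t
sa, s1a = a.sqrt(), (1 - a).sqrt()
x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
x_t = sa * x0 + s1a * eps        # q_sample
v = sa * eps - s1a * x0          # get_v
x0_rec = sa * x_t - s1a * v      # predict_start_from_z_and_v
eps_rec = sa * v + s1a * x_t     # predict_eps_from_z_and_v
```

Expanding the terms, the cross terms `sqrt(a(1-a)) * eps` cancel in `x0_rec` and the `x0` terms cancel in `eps_rec`, leaving the originals.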
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
for k in self.ucg_training:
p = self.ucg_training[k]["p"]
val = self.ucg_training[k]["val"]
if val is None:
val = ""
for i in range(len(batch[k])):
if self.ucg_prng.choice(2, p=[1 - p, p]):
batch[k][i] = val
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
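`_get_rows_from_list` turns a list of n logged batches (each `b, c, h, w`) into one `(b*n, c, h, w)` tensor ordered so each grid row shows one batch element across all n steps. A pure-torch shape walk-through equivalent to the two einops rearranges (values encode the step index so the ordering is visible):

```python
import torch

n, b, c, h, w = 5, 2, 3, 4, 4
# Each "step" batch is filled with its step index
samples = torch.stack([torch.full((b, c, h, w), float(i)) for i in range(n)])  # (n, b, c, h, w)
grid = samples.permute(1, 0, 2, 3, 4)    # 'n b c h w -> b n c h w'
grid = grid.reshape(b * n, c, h, w)      # 'b n c h w -> (b n) c h w'
```

After the reshape, images 0..n-1 belong to batch element 0 in step order, images n..2n-1 to batch element 1, and so on; `make_grid(..., nrow=n)` then lays each batch element out as one row.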
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
force_null_conditioning=False,
*args, **kwargs):
self.force_null_conditioning = force_null_conditioning
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
reset_ema = kwargs.pop("reset_ema", False)
reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
if reset_ema:
assert self.use_ema
print(
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
def delta_border(self, h, w):
"""
:param h: height
:param w: width
        :return: normalized distance to image border,
        with min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
return edge_dist
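A standalone sketch of `delta_border`: the coordinate grid is normalized by the lower-right corner, and the minimum of the distances to the four borders gives 0 at the border and up to 0.5 at the image center. This re-implementation mirrors the method above so it can be checked in isolation:

```python
import torch

def delta_border(h, w):
    # Normalized (y, x) coordinate grid in [0, 1]
    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
    arr = torch.cat([y, x], dim=-1) / torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    dist_lu = torch.min(arr, dim=-1, keepdim=True)[0]       # distance to top/left
    dist_rd = torch.min(1 - arr, dim=-1, keepdim=True)[0]   # distance to bottom/right
    return torch.min(torch.cat([dist_lu, dist_rd], dim=-1), dim=-1)[0]

d = delta_border(5, 5)
```

On a 5x5 grid, every border pixel maps to 0 and the center pixel to 0.5; `get_weighting` then clips this map to down-weight crop borders when blending overlapping tiles.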
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"], )
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"])
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
dilation=1, padding=0,
stride=(stride[0] * uf, stride[1] * uf))
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
dilation=1, padding=0,
stride=(stride[0] // df, stride[1] // df))
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
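A minimal standalone sketch (not part of the model) of the `Unfold`/`Fold` round trip this method sets up: `Fold` sums overlapping crops, so dividing by the folded all-ones "weighting" normalizes the overlap and recovers the input exactly.

```python
import torch

# Split an image into overlapping crops and reassemble it.
x = torch.arange(1 * 1 * 8 * 8, dtype=torch.float32).reshape(1, 1, 8, 8)
kernel, stride = (4, 4), (2, 2)

unfold = torch.nn.Unfold(kernel_size=kernel, stride=stride)
fold = torch.nn.Fold(output_size=(8, 8), kernel_size=kernel, stride=stride)

patches = unfold(x)            # (1, 16, L): L overlapping 4x4 crops
ones = torch.ones_like(patches)
normalization = fold(ones)     # counts how often each pixel was covered
reassembled = fold(patches) / normalization

assert torch.allclose(reassembled, x)
```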
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None, return_x=False):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None and not self.force_null_conditioning:
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox', "txt"]:
xc = batch[cond_key]
elif cond_key in ['class_label', 'cls']:
xc = batch
else:
xc = super().get_input(batch, cond_key).to(self.device)
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
ckey = __conditioning_keys__[self.model.conditioning_key]
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
else:
c = None
xc = None
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
c = {'pos_x': pos_x, 'pos_y': pos_y}
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z)
out.extend([x, xrec])
if return_x:
out.extend([x])
if return_original_cond:
out.append(xc)
return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is expected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
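A short sketch of the three regression targets selected above, assuming the standard definitions `x_t = sqrt(a) * x0 + sqrt(1 - a) * eps` and `v = sqrt(a) * eps - sqrt(1 - a) * x0` (with `a` the cumulative alpha at step `t`); a perfect v-prediction lets `x0` be recovered from `x_t`.

```python
import torch

a = torch.tensor(0.7)            # illustrative alpha_cumprod value
x0 = torch.randn(2, 3)
eps = torch.randn(2, 3)

x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps   # forward diffusion sample
v = a.sqrt() * eps - (1 - a).sqrt() * x0     # "v" target

# Recover x0 from x_t given an exact v-prediction:
x0_rec = a.sqrt() * x_t - (1 - a).sqrt() * v
assert torch.allclose(x0_rec, x0, atol=1e-6)
```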
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_codebook_ids:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
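A minimal illustration of the "no noise when t == 0" mask used above: it broadcasts per-sample over all non-batch dimensions and is zero exactly where `t == 0`, so the final denoising step returns the posterior mean without added noise.

```python
import torch

t = torch.tensor([3, 0, 1])                 # per-sample timesteps
x_shape = (3, 2, 4, 4)
nonzero_mask = (1 - (t == 0).float()).reshape(3, *((1,) * (len(x_shape) - 1)))

assert nonzero_mask.shape == (3, 1, 1, 1)
assert nonzero_mask[1].item() == 0.0        # no noise for the t == 0 sample
assert nonzero_mask[0].item() == 1.0
```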
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if type(temperature) == float:
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if timesteps is None:
timesteps = self.num_timesteps
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
if mask is not None:
assert x0 is not None
assert x0.shape[2:] == mask.shape[2:] # spatial size has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None, **kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
return_intermediates=return_intermediates, x_T=x_T,
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
shape, cond, verbose=False, **kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True, **kwargs)
return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
if self.cond_stage_key in ["class_label", "cls"]:
xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
return self.get_learned_conditioning(xc)
else:
raise NotImplementedError("todo")
if isinstance(c, list): # in case the encoder gives us a list
for i in range(len(c)):
c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
else:
c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', "cls"]:
try:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
except KeyError:
# probably no "human_label" in batch
pass
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if unconditional_guidance_scale > 1.0:
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
if self.model.conditioning_key == "crossattn-adm":
uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
mask = 1. - mask
with ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [opt], scheduler
return opt
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
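A standalone sketch of what `to_rgb` does: project an arbitrary C-channel map to 3 channels with a fixed random 1x1 convolution, then min-max rescale the result into [-1, 1] for visualization.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 8, 8)                 # C=16 conditioning map
colorize = torch.randn(3, 16, 1, 1)          # fixed random projection
rgb = F.conv2d(x, weight=colorize)
rgb = 2.0 * (rgb - rgb.min()) / (rgb.max() - rgb.min()) - 1.0

assert rgb.shape == (2, 3, 8, 8)
assert abs(rgb.min().item() + 1.0) < 1e-6 and abs(rgb.max().item() - 1.0) < 1e-6
```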
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1)
else:
cc = c_crossattn
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
else:
raise NotImplementedError()
return out
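A hypothetical shape sketch of the two most common conditioning modes dispatched in `DiffusionWrapper.forward`: `'concat'` stacks conditioning along the channel axis of the latent, while `'crossattn'` concatenates context tensors along the token axis (tensor sizes below are illustrative).

```python
import torch

x = torch.randn(2, 4, 8, 8)                  # latent input
c_concat = [torch.randn(2, 1, 8, 8)]         # e.g. a mask channel
c_crossattn = [torch.randn(2, 77, 768), torch.randn(2, 77, 768)]

xc = torch.cat([x] + c_concat, dim=1)        # 'concat': channels 4 -> 5
cc = torch.cat(c_crossattn, 1)               # 'crossattn': tokens 77 -> 154

assert xc.shape == (2, 5, 8, 8)
assert cc.shape == (2, 154, 768)
```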
class LatentUpscaleDiffusion(LatentDiffusion):
def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
super().__init__(*args, **kwargs)
# assumes that neither the cond_stage nor the low_scale_model contain trainable params
assert not self.cond_stage_trainable
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
self.noise_level_key = noise_level_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
if not log_mode:
z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
else:
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
x_low = batch[self.low_scale_key][:bs]
x_low = rearrange(x_low, 'b h w c -> b c h w')
x_low = x_low.to(memory_format=torch.contiguous_format).float()
zx, noise_level = self.low_scale_model(x_low)
if self.noise_level_key is not None:
# get noise level from batch instead, e.g. when extracting a custom noise level for bsr
raise NotImplementedError('TODO')
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
if log_mode:
# TODO: maybe disable if too expensive
x_low_rec = self.low_scale_model.decode(zx)
return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
return z, all_conds
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
log_mode=True)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
log["x_lr"] = x_low
log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', 'cls']:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# TODO explore better "unconditional" choices for the other keys
# maybe guide away from empty text label and highest noise level and maximally degraded zx?
uc = dict()
for k in c:
if k == "c_crossattn":
assert isinstance(c[k], list) and len(c[k]) == 1
uc[k] = [uc_tmp]
elif k == "c_adm": # todo: only run with text-based guidance?
assert isinstance(c[k], torch.Tensor)
#uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
uc[k] = c[k]
elif isinstance(c[k], list):
uc[k] = [c[k][i] for i in range(len(c[k]))]
else:
uc[k] = c[k]
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
return log
class LatentFinetuneDiffusion(LatentDiffusion):
"""
Basis for different finetunes, such as inpainting or depth2image
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
concat_keys: tuple,
finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
"model_ema.diffusion_modelinput_blocks00weight"
),
keep_finetune_dims=4,
# if model was trained without concat mode before and we would like to keep these channels
c_concat_log_start=None, # to log reconstruction of c_concat codes
c_concat_log_end=None,
*args, **kwargs
):
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", list())
super().__init__(*args, **kwargs)
self.finetune_keys = finetune_keys
self.concat_keys = concat_keys
self.keep_dims = keep_finetune_dims
self.c_concat_log_start = c_concat_log_start
self.c_concat_log_end = c_concat_log_end
if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
if exists(ckpt_path):
self.init_from_ckpt(ckpt_path, ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
# make it explicit, finetune by including extra input channels
if exists(self.finetune_keys) and k in self.finetune_keys:
new_entry = None
for name, param in self.named_parameters():
if name in self.finetune_keys:
print(
f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
new_entry = torch.zeros_like(param) # zero init
assert exists(new_entry), 'did not find matching parameter to modify'
new_entry[:, :self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
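A sketch of the finetuning trick in `init_from_ckpt`: when the first conv layer gains extra input channels (e.g. for inpainting masks), the pretrained weights are copied into the first `keep_dims` input channels and the rest are zero-initialized, so the network initially ignores the new inputs. Tensor sizes below are illustrative.

```python
import torch

keep_dims = 4
old_weight = torch.randn(320, keep_dims, 3, 3)      # pretrained conv weight
new_weight = torch.zeros(320, keep_dims + 5, 3, 3)  # widened conv, zero init
new_weight[:, :keep_dims, ...] = old_weight

assert torch.equal(new_weight[:, :keep_dims], old_weight)
assert new_weight[:, keep_dims:].abs().sum().item() == 0.0
```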
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', 'cls']:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
uc_cat = c_cat
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log
class LatentInpaintDiffusion(LatentFinetuneDiffusion):
    """
    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
    e.g. mask as concat and text via cross-attn.
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(self,
                 concat_keys=("mask", "masked_image"),
                 masked_image_key="masked_image",
                 *args, **kwargs
                 ):
        super().__init__(concat_keys, *args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
SYMBOL INDEX (1216 symbols across 67 files)
FILE: ToonCrafter/cldm/cldm.py
class ControlledUnetModel (line 23) | class ControlledUnetModel(UNetModel):
method forward (line 24) | def forward(self, x, timesteps, context=None, features_adapter=None, f...
class ControlNet (line 90) | class ControlNet(nn.Module):
method __init__ (line 91) | def __init__(
method make_zero_conv (line 323) | def make_zero_conv(self, channels):
method forward (line 326) | def forward(self, x, hint, timesteps, context, **kwargs):
class ControlLDM (line 351) | class ControlLDM(LatentDiffusion):
method __init__ (line 353) | def __init__(self, control_stage_config, control_key, only_mid_control...
method get_input (line 361) | def get_input(self, batch, k, bs=None, *args, **kwargs):
method apply_model (line 371) | def apply_model(self, x_noisy, t, cond, *args, **kwargs):
method get_unconditional_conditioning (line 387) | def get_unconditional_conditioning(self, N):
method log_images (line 391) | def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50,...
method sample_log (line 452) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
method configure_optimizers (line 459) | def configure_optimizers(self):
method low_vram_shift (line 468) | def low_vram_shift(self, is_diffusing):
FILE: ToonCrafter/cldm/ddim_hacked.py
class DDIMSampler (line 10) | class DDIMSampler(object):
method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 17) | def register_buffer(self, name, attr):
method make_schedule (line 23) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 55) | def sample(self,
method ddim_sampling (line 123) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 181) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method encode (line 234) | def encode(self, x0, c, t_enc, use_original_steps=False, return_interm...
method stochastic_encode (line 282) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 298) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
FILE: ToonCrafter/cldm/hack.py
function disable_verbosity (line 11) | def disable_verbosity():
function enable_sliced_attention (line 17) | def enable_sliced_attention():
function hack_everything (line 23) | def hack_everything(clip_skip=0):
function _hacked_clip_forward (line 32) | def _hacked_clip_forward(self, text):
function _hacked_sliced_attentin_forward (line 72) | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
FILE: ToonCrafter/cldm/logger.py
class ImageLogger (line 11) | class ImageLogger(Callback):
method __init__ (line 12) | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, inc...
method log_local (line 28) | def log_local(self, save_dir, split, images, global_step, current_epoc...
method log_img (line 42) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 71) | def check_frequency(self, check_idx):
method on_train_batch_end (line 74) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
FILE: ToonCrafter/cldm/model.py
function get_state_dict (line 8) | def get_state_dict(d):
function load_state_dict (line 12) | def load_state_dict(ckpt_path, location='cpu'):
function create_model (line 24) | def create_model(config_path):
FILE: ToonCrafter/gradio_app.py
function dynamicrafter_demo (line 16) | def dynamicrafter_demo(result_dir='./tmp/', res=512):
function get_parser (line 67) | def get_parser():
FILE: ToonCrafter/ldm/data/util.py
class AddMiDaS (line 6) | class AddMiDaS(object):
method __init__ (line 7) | def __init__(self, model_type):
method pt2np (line 11) | def pt2np(self, x):
method np2pt (line 15) | def np2pt(self, x):
method __call__ (line 19) | def __call__(self, sample):
FILE: ToonCrafter/ldm/models/autoencoder.py
class AutoencoderKL (line 13) | class AutoencoderKL(pl.LightningModule):
method __init__ (line 14) | def __init__(self,
method init_from_ckpt (line 52) | def init_from_ckpt(self, path, ignore_keys=list()):
method ema_scope (line 64) | def ema_scope(self, context=None):
method on_train_batch_end (line 78) | def on_train_batch_end(self, *args, **kwargs):
method encode (line 82) | def encode(self, x):
method decode (line 88) | def decode(self, z):
method forward (line 93) | def forward(self, input, sample_posterior=True):
method get_input (line 102) | def get_input(self, batch, k):
method training_step (line 109) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 130) | def validation_step(self, batch, batch_idx):
method _validation_step (line 136) | def _validation_step(self, batch, batch_idx, postfix=""):
method configure_optimizers (line 150) | def configure_optimizers(self):
method get_last_layer (line 163) | def get_last_layer(self):
method log_images (line 167) | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
method to_rgb (line 192) | def to_rgb(self, x):
class IdentityFirstStage (line 201) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 202) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 206) | def encode(self, x, *args, **kwargs):
method decode (line 209) | def decode(self, x, *args, **kwargs):
method quantize (line 212) | def quantize(self, x, *args, **kwargs):
method forward (line 217) | def forward(self, x, *args, **kwargs):
FILE: ToonCrafter/ldm/models/diffusion/ddim.py
class DDIMSampler (line 10) | class DDIMSampler(object):
method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 17) | def register_buffer(self, name, attr):
method make_schedule (line 23) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 55) | def sample(self,
method ddim_sampling (line 123) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 181) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method encode (line 254) | def encode(self, x0, c, t_enc, use_original_steps=False, return_interm...
method stochastic_encode (line 301) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 317) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
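The p_sample_ddim method listed above performs the standard DDIM update: it recovers a prediction of x_0 from the model's eps-output, then steps to the previous cumulative-alpha level, with eta interpolating between deterministic DDIM (eta=0) and noisier ancestral-style sampling. A hedged NumPy sketch of that single step (the cumulative alphas a_t, a_prev are assumed given):

```python
import numpy as np

def ddim_step(x_t, eps_pred, a_t, a_prev, eta=0.0, noise=None):
    # Recover x0 from the eps-prediction at noise level a_t = abar_t.
    pred_x0 = (x_t - np.sqrt(1.0 - a_t) * eps_pred) / np.sqrt(a_t)
    # Stochasticity sigma; eta = 0 makes the step fully deterministic.
    sigma = eta * np.sqrt((1.0 - a_prev) / (1.0 - a_t)) * np.sqrt(1.0 - a_t / a_prev)
    # Direction pointing back toward x_t at the previous noise level.
    dir_xt = np.sqrt(1.0 - a_prev - sigma ** 2) * eps_pred
    if noise is None:
        noise = np.zeros_like(x_t)
    return np.sqrt(a_prev) * pred_x0 + dir_xt + sigma * noise
```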
FILE: ToonCrafter/ldm/models/diffusion/ddpm.py
function disabled_train (line 36) | def disabled_train(self, mode=True):
function uniform_on_device (line 42) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 46) | class DDPM(pl.LightningModule):
method __init__ (line 48) | def __init__(self,
method register_schedule (line 138) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 195) | def ema_scope(self, context=None):
method init_from_ckpt (line 210) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 272) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 284) | def predict_start_from_noise(self, x_t, t, noise):
method predict_start_from_z_and_v (line 290) | def predict_start_from_z_and_v(self, x_t, t, v):
method predict_eps_from_z_and_v (line 298) | def predict_eps_from_z_and_v(self, x_t, t, v):
method q_posterior (line 304) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 313) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 326) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 335) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 350) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 356) | def q_sample(self, x_start, t, noise=None):
method get_v (line 361) | def get_v(self, x, noise, t):
method get_loss (line 367) | def get_loss(self, pred, target, mean=True):
method p_losses (line 382) | def p_losses(self, x_start, t, noise=None):
method forward (line 413) | def forward(self, x, *args, **kwargs):
method get_input (line 419) | def get_input(self, batch, k):
method shared_step (line 427) | def shared_step(self, batch):
method training_step (line 432) | def training_step(self, batch, batch_idx):
method validation_step (line 457) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 465) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 469) | def _get_rows_from_list(self, samples):
method log_images (line 477) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 514) | def configure_optimizers(self):
class LatentDiffusion (line 523) | class LatentDiffusion(DDPM):
method __init__ (line 526) | def __init__(self,
method make_cond_schedule (line 584) | def make_cond_schedule(self, ):
method on_train_batch_start (line 591) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 606) | def register_schedule(self,
method instantiate_first_stage (line 615) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 622) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 643) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 655) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 664) | def get_learned_conditioning(self, c):
method meshgrid (line 677) | def meshgrid(self, h, w):
method delta_border (line 684) | def delta_border(self, h, w):
method get_weighting (line 698) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 714) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 767) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 820) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method encode_first_stage (line 831) | def encode_first_stage(self, x):
method shared_step (line 834) | def shared_step(self, batch, **kwargs):
method forward (line 839) | def forward(self, x, c, *args, **kwargs):
method apply_model (line 850) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 867) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 871) | def _prior_bpd(self, x_start):
method p_losses (line 885) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 922) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 954) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 985) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1041) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1092) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1110) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
method get_unconditional_conditioning (line 1124) | def get_unconditional_conditioning(self, batch_size, null_label=None):
method log_images (line 1149) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ...
method configure_optimizers (line 1278) | def configure_optimizers(self):
method to_rgb (line 1303) | def to_rgb(self, x):
class DiffusionWrapper (line 1312) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1313) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1320) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = Non...
class LatentUpscaleDiffusion (line 1354) | class LatentUpscaleDiffusion(LatentDiffusion):
method __init__ (line 1355) | def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_...
method instantiate_low_stage (line 1363) | def instantiate_low_stage(self, config):
method get_input (line 1371) | def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
method log_images (line 1393) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
class LatentFinetuneDiffusion (line 1492) | class LatentFinetuneDiffusion(LatentDiffusion):
method __init__ (line 1498) | def __init__(self,
method init_from_ckpt (line 1521) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method log_images (line 1553) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
class LatentInpaintDiffusion (line 1634) | class LatentInpaintDiffusion(LatentFinetuneDiffusion):
method __init__ (line 1641) | def __init__(self,
method get_input (line 1651) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
method log_images (line 1677) | def log_images(self, *args, **kwargs):
class LatentDepth2ImageDiffusion (line 1684) | class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
method __init__ (line 1689) | def __init__(self, depth_stage_config, concat_keys=("midas_in",), *arg...
method get_input (line 1695) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
method log_images (line 1728) | def log_images(self, *args, **kwargs):
class LatentUpscaleFinetuneDiffusion (line 1737) | class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
method __init__ (line 1741) | def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
method instantiate_low_stage (line 1752) | def instantiate_low_stage(self, config):
method get_input (line 1760) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
method log_images (line 1794) | def log_images(self, *args, **kwargs):
FILE: ToonCrafter/ldm/models/diffusion/dpm_solver/dpm_solver.py
class NoiseScheduleVP (line 7) | class NoiseScheduleVP:
method __init__ (line 8) | def __init__(
method marginal_log_mean_coeff (line 106) | def marginal_log_mean_coeff(self, t):
method marginal_alpha (line 120) | def marginal_alpha(self, t):
method marginal_std (line 126) | def marginal_std(self, t):
method marginal_lambda (line 132) | def marginal_lambda(self, t):
method inverse_lambda (line 140) | def inverse_lambda(self, lamb):
function model_wrapper (line 161) | def model_wrapper(
class DPM_Solver (line 319) | class DPM_Solver:
method __init__ (line 320) | def __init__(self, model_fn, noise_schedule, predict_x0=False, thresho...
method noise_prediction_fn (line 346) | def noise_prediction_fn(self, x, t):
method data_prediction_fn (line 352) | def data_prediction_fn(self, x, t):
method model_fn (line 367) | def model_fn(self, x, t):
method get_time_steps (line 376) | def get_time_steps(self, skip_type, t_T, t_0, N, device):
method get_orders_and_timesteps_for_singlestep_solver (line 405) | def get_orders_and_timesteps_for_singlestep_solver(self, steps, order,...
method denoise_to_zero_fn (line 463) | def denoise_to_zero_fn(self, x, s):
method dpm_solver_first_update (line 469) | def dpm_solver_first_update(self, x, s, t, model_s=None, return_interm...
method singlestep_dpm_solver_second_update (line 515) | def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s...
method singlestep_dpm_solver_third_update (line 599) | def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2...
method multistep_dpm_solver_second_update (line 723) | def multistep_dpm_solver_second_update(self, x, model_prev_list, t_pre...
method multistep_dpm_solver_third_update (line 780) | def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev...
method singlestep_dpm_solver_update (line 827) | def singlestep_dpm_solver_update(self, x, s, t, order, return_intermed...
method multistep_dpm_solver_update (line 855) | def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list,...
method dpm_solver_adaptive (line 878) | def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0....
method sample (line 939) | def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_...
function interpolate_fn (line 1104) | def interpolate_fn(x, xp, yp):
function expand_dims (line 1145) | def expand_dims(v, dims):
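NoiseScheduleVP.marginal_lambda above returns the half log-SNR lambda_t = log(alpha_t) - log(sigma_t), the variable in which DPM_Solver chooses its step sizes. A sketch for the continuous-time linear VP schedule; the beta range 0.1-20 is an assumption taken from the common DPM-Solver setup, not read from this repo's configs:

```python
import numpy as np

def vp_linear_log_alpha(t, beta_0=0.1, beta_1=20.0):
    # log alpha_t for the continuous-time linear VP schedule (assumed betas).
    return -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0

def marginal_lambda(t):
    # Half log-SNR lambda_t = log(alpha_t) - log(sigma_t); monotonically
    # decreasing in t, so high t (more noise) means low lambda.
    log_alpha = vp_linear_log_alpha(t)
    sigma = np.sqrt(1.0 - np.exp(2.0 * log_alpha))
    return log_alpha - np.log(sigma)
```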
FILE: ToonCrafter/ldm/models/diffusion/dpm_solver/sampler.py
class DPMSolverSampler (line 13) | class DPMSolverSampler(object):
method __init__ (line 14) | def __init__(self, model, **kwargs):
method register_buffer (line 20) | def register_buffer(self, name, attr):
method sample (line 27) | def sample(self,
FILE: ToonCrafter/ldm/models/diffusion/plms.py
class PLMSSampler (line 12) | class PLMSSampler(object):
method __init__ (line 13) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 19) | def register_buffer(self, name, attr):
method make_schedule (line 25) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 59) | def sample(self,
method plms_sampling (line 118) | def plms_sampling(self, cond, shape,
method p_sample_plms (line 178) | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_origin...
FILE: ToonCrafter/ldm/models/diffusion/sampling_util.py
function append_dims (line 5) | def append_dims(x, target_dims):
function norm_thresholding (line 14) | def norm_thresholding(x0, value):
function spatial_norm_thresholding (line 19) | def spatial_norm_thresholding(x0, value):
FILE: ToonCrafter/ldm/modules/attention.py
function exists (line 23) | def exists(val):
function uniq (line 27) | def uniq(arr):
function default (line 31) | def default(val, d):
function max_neg_value (line 37) | def max_neg_value(t):
function init_ (line 41) | def init_(tensor):
class GEGLU (line 49) | class GEGLU(nn.Module):
method __init__ (line 50) | def __init__(self, dim_in, dim_out):
method forward (line 54) | def forward(self, x):
class FeedForward (line 59) | class FeedForward(nn.Module):
method __init__ (line 60) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 75) | def forward(self, x):
function zero_module (line 79) | def zero_module(module):
function Normalize (line 88) | def Normalize(in_channels):
class SpatialSelfAttention (line 92) | class SpatialSelfAttention(nn.Module):
method __init__ (line 93) | def __init__(self, in_channels):
method forward (line 119) | def forward(self, x):
class CrossAttention (line 145) | class CrossAttention(nn.Module):
method __init__ (line 146) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 163) | def forward(self, x, context=None, mask=None):
class MemoryEfficientCrossAttention (line 197) | class MemoryEfficientCrossAttention(nn.Module):
method __init__ (line 199) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 216) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 246) | class BasicTransformerBlock(nn.Module):
method __init__ (line 251) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 268) | def forward(self, x, context=None):
method _forward (line 271) | def _forward(self, x, context=None):
class SpatialTransformer (line 278) | class SpatialTransformer(nn.Module):
method __init__ (line 287) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 321) | def forward(self, x, context=None):
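CrossAttention.forward listed above is scaled dot-product attention computed with einsum. The sketch below reproduces that core in NumPy; the learned to_q/to_k/to_v projections and the multi-head split are omitted (identity projections assumed), so inputs are plain (batch, tokens, dim) arrays:

```python
import numpy as np

def cross_attention(x, context):
    # q attends from x over keys/values drawn from context
    # (projections omitted for the sketch; shapes are (B, T, D)).
    q, k, v = x, context, context
    scale = q.shape[-1] ** -0.5
    sim = np.einsum('bid,bjd->bij', q, k) * scale         # attention logits
    attn = np.exp(sim - sim.max(axis=-1, keepdims=True))  # stable softmax
    attn /= attn.sum(axis=-1, keepdims=True)
    return np.einsum('bij,bjd->bid', attn, v)
```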
FILE: ToonCrafter/ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 20) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 41) | def nonlinearity(x):
function Normalize (line 46) | def Normalize(in_channels, num_groups=32):
class Upsample (line 50) | class Upsample(nn.Module):
method __init__ (line 51) | def __init__(self, in_channels, with_conv):
method forward (line 61) | def forward(self, x):
class Downsample (line 68) | class Downsample(nn.Module):
method __init__ (line 69) | def __init__(self, in_channels, with_conv):
method forward (line 80) | def forward(self, x):
class ResnetBlock (line 90) | class ResnetBlock(nn.Module):
method __init__ (line 91) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 129) | def forward(self, x, temb):
class AttnBlock (line 152) | class AttnBlock(nn.Module):
method __init__ (line 153) | def __init__(self, in_channels):
method forward (line 179) | def forward(self, x):
class MemoryEfficientAttnBlock (line 205) | class MemoryEfficientAttnBlock(nn.Module):
method __init__ (line 212) | def __init__(self, in_channels):
method forward (line 239) | def forward(self, x):
class MemoryEfficientCrossAttentionWrapper (line 271) | class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
method forward (line 272) | def forward(self, x, context=None, mask=None):
function make_attn (line 280) | def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
class Model (line 300) | class Model(nn.Module):
method __init__ (line 301) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 400) | def forward(self, x, t=None, context=None):
method get_last_layer (line 448) | def get_last_layer(self):
class Encoder (line 452) | class Encoder(nn.Module):
method __init__ (line 453) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 518) | def forward(self, x):
class Decoder (line 546) | class Decoder(nn.Module):
method __init__ (line 547) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 619) | def forward(self, z):
class SimpleDecoder (line 655) | class SimpleDecoder(nn.Module):
method __init__ (line 656) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 678) | def forward(self, x):
class UpsampleDecoder (line 691) | class UpsampleDecoder(nn.Module):
method __init__ (line 692) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 725) | def forward(self, x):
class LatentRescaler (line 739) | class LatentRescaler(nn.Module):
method __init__ (line 740) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 764) | def forward(self, x):
class MergedRescaleEncoder (line 776) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 777) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 789) | def forward(self, x):
class MergedRescaleDecoder (line 795) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 796) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 806) | def forward(self, x):
class Upsampler (line 812) | class Upsampler(nn.Module):
method __init__ (line 813) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 825) | def forward(self, x):
class Resize (line 831) | class Resize(nn.Module):
method __init__ (line 832) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 847) | def forward(self, x, scale_factor=1.0):
FILE: ToonCrafter/ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 23) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 26) | def convert_module_to_f32(x):
class AttentionPool2d (line 31) | class AttentionPool2d(nn.Module):
method __init__ (line 36) | def __init__(
method forward (line 50) | def forward(self, x):
class TimestepBlock (line 61) | class TimestepBlock(nn.Module):
method forward (line 67) | def forward(self, x, emb):
class TimestepEmbedSequential (line 73) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 79) | def forward(self, x, emb, context=None, check=False):
class Upsample (line 90) | class Upsample(nn.Module):
method __init__ (line 99) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 108) | def forward(self, x):
class TransposedUpsample (line 120) | class TransposedUpsample(nn.Module):
method __init__ (line 122) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 129) | def forward(self,x):
class Downsample (line 133) | class Downsample(nn.Module):
method __init__ (line 142) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
method forward (line 157) | def forward(self, x):
class ResBlock (line 162) | class ResBlock(TimestepBlock):
method __init__ (line 178) | def __init__(
method forward (line 242) | def forward(self, x, emb):
method _forward (line 254) | def _forward(self, x, emb):
class AttentionBlock (line 277) | class AttentionBlock(nn.Module):
method __init__ (line 284) | def __init__(
method forward (line 313) | def forward(self, x):
method _forward (line 317) | def _forward(self, x):
function count_flops_attn (line 326) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 346) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 351) | def __init__(self, n_heads):
method forward (line 355) | def forward(self, qkv):
method count_flops (line 374) | def count_flops(model, _x, y):
class QKVAttention (line 378) | class QKVAttention(nn.Module):
method __init__ (line 383) | def __init__(self, n_heads):
method forward (line 387) | def forward(self, qkv):
method count_flops (line 408) | def count_flops(model, _x, y):
class UNetModel (line 412) | class UNetModel(nn.Module):
method __init__ (line 442) | def __init__(
method convert_to_fp16 (line 738) | def convert_to_fp16(self):
method convert_to_fp32 (line 746) | def convert_to_fp32(self):
method forward (line 754) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
FILE: ToonCrafter/ldm/modules/diffusionmodules/upscaling.py
class AbstractLowScaleModel (line 10) | class AbstractLowScaleModel(nn.Module):
method __init__ (line 12) | def __init__(self, noise_schedule_config=None):
method register_schedule (line 17) | def register_schedule(self, beta_schedule="linear", timesteps=1000,
method q_sample (line 44) | def q_sample(self, x_start, t, noise=None):
method forward (line 49) | def forward(self, x):
method decode (line 52) | def decode(self, x):
class SimpleImageConcat (line 56) | class SimpleImageConcat(AbstractLowScaleModel):
method __init__ (line 58) | def __init__(self):
method forward (line 62) | def forward(self, x):
class ImageConcatWithNoiseAugmentation (line 67) | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
method __init__ (line 68) | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cud...
method forward (line 72) | def forward(self, x, noise_level=None):
FILE: ToonCrafter/ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
method forward (line 121) | def forward(ctx, run_function, length, *args):
method backward (line 133) | def backward(ctx, *output_grads):
function timestep_embedding (line 154) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 177) | def zero_module(module):
function scale_module (line 186) | def scale_module(module, scale):
function mean_flat (line 195) | def mean_flat(tensor):
function normalization (line 202) | def normalization(channels):
class SiLU (line 212) | class SiLU(nn.Module):
method forward (line 213) | def forward(self, x):
class GroupNorm32 (line 217) | class GroupNorm32(nn.GroupNorm):
method forward (line 218) | def forward(self, x):
function conv_nd (line 221) | def conv_nd(dims, *args, **kwargs):
function linear (line 234) | def linear(*args, **kwargs):
function avg_pool_nd (line 241) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 254) | class HybridConditioner(nn.Module):
method __init__ (line 256) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 261) | def forward(self, c_concat, c_crossattn):
function noise_like (line 267) | def noise_like(shape, device, repeat=False):
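timestep_embedding in diffusionmodules/util.py above produces the Transformer-style sinusoidal embeddings fed to the UNet's time conditioning. A NumPy sketch, assuming an even dim (the repo additionally zero-pads odd dims, which is skipped here):

```python
import numpy as np

def timestep_embedding(timesteps, dim, max_period=10000):
    # First half cosines, second half sines, over log-spaced frequencies
    # from 1 down to ~1/max_period (even dim assumed).
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)
```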
FILE: ToonCrafter/ldm/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self):
method kl (line 39) | def kl(self, other=None):
method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
method mode (line 61) | def mode(self):
function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
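The `DiagonalGaussianDistribution` listed above parameterizes the VAE posterior with per-element mean and log-variance. A minimal pure-Python sketch of the two operations the listing names, reparameterized sampling and the closed-form KL against a standard normal (function names here are illustrative, not the repository's exact signatures):

```python
import math
import random

def sample(mean, logvar):
    """Reparameterized draw: x = mean + std * eps, with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * random.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]

def kl_to_standard_normal(mean, logvar):
    """KL(N(mean, var) || N(0, I)) = 0.5 * sum(mean^2 + var - 1 - logvar)."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv
                     for m, lv in zip(mean, logvar))
```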
FILE: ToonCrafter/ldm/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method reset_num_updates (line 25) | def reset_num_updates(self):
method forward (line 29) | def forward(self, model):
method copy_to (line 50) | def copy_to(self, model):
method store (line 59) | def store(self, parameters):
method restore (line 68) | def restore(self, parameters):
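`LitEma` above maintains an exponential moving average of model weights, with a warm-up schedule that lowers the effective decay early in training so the shadow weights catch up quickly. A pure-Python sketch of the update rule over a dict of scalar "parameters" (the real class stores shadow copies as torch buffers; the class and attribute names here are illustrative):

```python
class EmaTracker:
    """EMA over a dict of scalars, mirroring LitEma's warm-up decay rule."""

    def __init__(self, params, decay=0.9999):
        self.decay = decay
        self.num_updates = 0
        self.shadow = dict(params)  # shadow copy nudged toward live params

    def update(self, params):
        self.num_updates += 1
        # Warm-up: (1 + n) / (10 + n) < decay for small n, so early updates
        # move the shadow much faster than the asymptotic decay would.
        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        for k, v in params.items():
            self.shadow[k] = self.shadow[k] - (1.0 - decay) * (self.shadow[k] - v)
```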
FILE: ToonCrafter/ldm/modules/encoders/modules.py
class AbstractEncoder (line 11) | class AbstractEncoder(nn.Module):
method __init__ (line 12) | def __init__(self):
method encode (line 15) | def encode(self, *args, **kwargs):
class IdentityEncoder (line 19) | class IdentityEncoder(AbstractEncoder):
method encode (line 21) | def encode(self, x):
class ClassEmbedder (line 25) | class ClassEmbedder(nn.Module):
method __init__ (line 26) | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
method forward (line 33) | def forward(self, batch, key=None, disable_dropout=False):
method get_unconditional_conditioning (line 45) | def get_unconditional_conditioning(self, bs, device="cuda"):
function disabled_train (line 52) | def disabled_train(self, mode=True):
class FrozenT5Embedder (line 58) | class FrozenT5Embedder(AbstractEncoder):
method __init__ (line 60) | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_...
method freeze (line 69) | def freeze(self):
method forward (line 75) | def forward(self, text):
method encode (line 84) | def encode(self, text):
class FrozenCLIPEmbedder (line 88) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 95) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 111) | def freeze(self):
method forward (line 117) | def forward(self, text):
method encode (line 130) | def encode(self, text):
class FrozenOpenCLIPEmbedder (line 134) | class FrozenOpenCLIPEmbedder(AbstractEncoder):
method __init__ (line 143) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
method freeze (line 163) | def freeze(self):
method forward (line 168) | def forward(self, text):
method encode_with_transformer (line 173) | def encode_with_transformer(self, text):
method text_transformer_forward (line 182) | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
method encode (line 192) | def encode(self, text):
class FrozenCLIPT5Encoder (line 196) | class FrozenCLIPT5Encoder(AbstractEncoder):
method __init__ (line 197) | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_ve...
method encode (line 205) | def encode(self, text):
method forward (line 208) | def forward(self, text):
FILE: ToonCrafter/ldm/modules/image_degradation/bsrgan.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 339) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...
FILE: ToonCrafter/ldm/modules/image_degradation/bsrgan_light.py
function modcrop_np (line 28) | def modcrop_np(img, sf):
function analytic_kernel (line 48) | def analytic_kernel(k):
function anisotropic_Gaussian (line 64) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 85) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 98) | def shift_pixel(x, sf, upper_left=True):
function blur (line 127) | def blur(x, k):
function gen_kernel (line 144) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 186) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 200) | def fspecial_laplacian(alpha):
function fspecial (line 209) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 227) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 239) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 261) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 283) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 298) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 324) | def add_blur(img, sf=4):
function add_resize (line 342) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 372) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 389) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 407) | def add_Poisson_noise(img):
function add_JPEG_noise (line 421) | def add_JPEG_noise(img):
function random_crop (line 430) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 441) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 533) | def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
FILE: ToonCrafter/ldm/modules/image_degradation/utils_image.py
function is_image_file (line 29) | def is_image_file(filename):
function get_timestamp (line 33) | def get_timestamp():
function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
function get_image_paths (line 67) | def get_image_paths(dataroot):
function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
function imssave (line 112) | def imssave(imgs, img_path):
function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
function mkdir (line 153) | def mkdir(path):
function mkdirs (line 158) | def mkdirs(paths):
function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
function imread_uint (line 185) | def imread_uint(path, n_channels=3):
function imsave (line 203) | def imsave(img, img_path):
function imwrite (line 209) | def imwrite(img, img_path):
function read_img (line 220) | def read_img(path):
function uint2single (line 249) | def uint2single(img):
function single2uint (line 254) | def single2uint(img):
function uint162single (line 259) | def uint162single(img):
function single2uint16 (line 264) | def single2uint16(img):
function uint2tensor4 (line 275) | def uint2tensor4(img):
function uint2tensor3 (line 282) | def uint2tensor3(img):
function tensor2uint (line 289) | def tensor2uint(img):
function single2tensor3 (line 302) | def single2tensor3(img):
function single2tensor4 (line 307) | def single2tensor4(img):
function tensor2single (line 312) | def tensor2single(img):
function tensor2single3 (line 320) | def tensor2single3(img):
function single2tensor5 (line 329) | def single2tensor5(img):
function single32tensor5 (line 333) | def single32tensor5(img):
function single42tensor4 (line 337) | def single42tensor4(img):
function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
function augment_img (line 380) | def augment_img(img, mode=0):
function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
function modcrop (line 494) | def modcrop(img_in, scale):
function shave (line 510) | def shave(img_in, border=0):
function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
function ssim (line 669) | def ssim(img1, img2):
function cubic (line 700) | def cubic(x):
function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 766) | def imresize(img, scale, antialiasing=True):
function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):
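`calculate_psnr` and `calculate_ssim` above are the standard image-quality metrics. A minimal sketch of PSNR for 8-bit images over flat pixel sequences (the repository's version works on numpy arrays and supports an optional border crop, which this sketch omits):

```python
import math

def psnr(img1, img2, max_val=255.0):
    """Peak signal-to-noise ratio between two equal-length pixel sequences."""
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        return float("inf")  # identical images
    return 20.0 * math.log10(max_val / math.sqrt(mse))
```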
FILE: ToonCrafter/ldm/modules/midas/api.py
function disabled_train (line 22) | def disabled_train(self, mode=True):
function load_midas_transform (line 28) | def load_midas_transform(model_type):
function load_model (line 73) | def load_model(model_type):
class MiDaSInference (line 137) | class MiDaSInference(nn.Module):
method __init__ (line 150) | def __init__(self, model_type):
method forward (line 157) | def forward(self, x):
FILE: ToonCrafter/ldm/modules/midas/midas/base_model.py
class BaseModel (line 4) | class BaseModel(torch.nn.Module):
method load (line 5) | def load(self, path):
FILE: ToonCrafter/ldm/modules/midas/midas/blocks.py
function _make_encoder (line 11) | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=F...
function _make_scratch (line 49) | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
function _make_pretrained_efficientnet_lite3 (line 78) | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
function _make_efficientnet_backbone (line 88) | def _make_efficientnet_backbone(effnet):
function _make_resnet_backbone (line 101) | def _make_resnet_backbone(resnet):
function _make_pretrained_resnext101_wsl (line 114) | def _make_pretrained_resnext101_wsl(use_pretrained):
class Interpolate (line 120) | class Interpolate(nn.Module):
method __init__ (line 124) | def __init__(self, scale_factor, mode, align_corners=False):
method forward (line 138) | def forward(self, x):
class ResidualConvUnit (line 155) | class ResidualConvUnit(nn.Module):
method __init__ (line 159) | def __init__(self, features):
method forward (line 177) | def forward(self, x):
class FeatureFusionBlock (line 194) | class FeatureFusionBlock(nn.Module):
method __init__ (line 198) | def __init__(self, features):
method forward (line 209) | def forward(self, *xs):
class ResidualConvUnit_custom (line 231) | class ResidualConvUnit_custom(nn.Module):
method __init__ (line 235) | def __init__(self, features, activation, bn):
method forward (line 263) | def forward(self, x):
class FeatureFusionBlock_custom (line 291) | class FeatureFusionBlock_custom(nn.Module):
method __init__ (line 295) | def __init__(self, features, activation, deconv=False, bn=False, expan...
method forward (line 320) | def forward(self, *xs):
FILE: ToonCrafter/ldm/modules/midas/midas/dpt_depth.py
function _make_fusion_block (line 15) | def _make_fusion_block(features, use_bn):
class DPT (line 26) | class DPT(BaseModel):
method __init__ (line 27) | def __init__(
method forward (line 67) | def forward(self, x):
class DPTDepthModel (line 88) | class DPTDepthModel(DPT):
method __init__ (line 89) | def __init__(self, path=None, non_negative=True, **kwargs):
method forward (line 107) | def forward(self, x):
FILE: ToonCrafter/ldm/modules/midas/midas/midas_net.py
class MidasNet (line 12) | class MidasNet(BaseModel):
method __init__ (line 16) | def __init__(self, path=None, features=256, non_negative=True):
method forward (line 49) | def forward(self, x):
FILE: ToonCrafter/ldm/modules/midas/midas/midas_net_custom.py
class MidasNet_small (line 12) | class MidasNet_small(BaseModel):
method __init__ (line 16) | def __init__(self, path=None, features=64, backbone="efficientnet_lite...
method forward (line 73) | def forward(self, x):
function fuse_model (line 109) | def fuse_model(m):
FILE: ToonCrafter/ldm/modules/midas/midas/transforms.py
function apply_min_size (line 6) | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AR...
class Resize (line 48) | class Resize(object):
method __init__ (line 52) | def __init__(
method constrain_to_multiple_of (line 94) | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
method get_size (line 105) | def get_size(self, width, height):
method __call__ (line 162) | def __call__(self, sample):
class NormalizeImage (line 197) | class NormalizeImage(object):
method __init__ (line 201) | def __init__(self, mean, std):
method __call__ (line 205) | def __call__(self, sample):
class PrepareForNet (line 211) | class PrepareForNet(object):
method __init__ (line 215) | def __init__(self):
method __call__ (line 218) | def __call__(self, sample):
FILE: ToonCrafter/ldm/modules/midas/midas/vit.py
class Slice (line 9) | class Slice(nn.Module):
method __init__ (line 10) | def __init__(self, start_index=1):
method forward (line 14) | def forward(self, x):
class AddReadout (line 18) | class AddReadout(nn.Module):
method __init__ (line 19) | def __init__(self, start_index=1):
method forward (line 23) | def forward(self, x):
class ProjectReadout (line 31) | class ProjectReadout(nn.Module):
method __init__ (line 32) | def __init__(self, in_features, start_index=1):
method forward (line 38) | def forward(self, x):
class Transpose (line 45) | class Transpose(nn.Module):
method __init__ (line 46) | def __init__(self, dim0, dim1):
method forward (line 51) | def forward(self, x):
function forward_vit (line 56) | def forward_vit(pretrained, x):
function _resize_pos_embed (line 100) | def _resize_pos_embed(self, posemb, gs_h, gs_w):
function forward_flex (line 117) | def forward_flex(self, x):
function get_activation (line 159) | def get_activation(name):
function get_readout_oper (line 166) | def get_readout_oper(vit_features, features, use_readout, start_index=1):
function _make_vit_b16_backbone (line 183) | def _make_vit_b16_backbone(
function _make_pretrained_vitl16_384 (line 297) | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=...
function _make_pretrained_vitb16_384 (line 310) | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=...
function _make_pretrained_deitb16_384 (line 319) | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks...
function _make_pretrained_deitb16_distil_384 (line 328) | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore"...
function _make_vit_b_rn50_backbone (line 343) | def _make_vit_b_rn50_backbone(
function _make_pretrained_vitb_rn50_384 (line 478) | def _make_pretrained_vitb_rn50_384(
FILE: ToonCrafter/ldm/modules/midas/utils.py
function read_pfm (line 9) | def read_pfm(path):
function write_pfm (line 58) | def write_pfm(path, image, scale=1):
function read_image (line 97) | def read_image(path):
function resize_image (line 116) | def resize_image(img):
function resize_depth (line 146) | def resize_depth(depth, width, height):
function write_depth (line 165) | def write_depth(path, depth, bits=1):
FILE: ToonCrafter/ldm/util.py
function log_txt_as_img (line 11) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 35) | def ismap(x):
function isimage (line 41) | def isimage(x):
function exists (line 47) | def exists(x):
function default (line 51) | def default(val, d):
function mean_flat (line 57) | def mean_flat(tensor):
function count_params (line 65) | def count_params(model, verbose=False):
function instantiate_from_config (line 72) | def instantiate_from_config(config):
function get_obj_from_str (line 82) | def get_obj_from_str(string, reload=False):
class AdamWwithEMAandWings (line 90) | class AdamWwithEMAandWings(optim.Optimizer):
method __init__ (line 92) | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, #...
method __setstate__ (line 113) | def __setstate__(self, state):
method step (line 119) | def step(self, closure=None):
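`get_obj_from_str` and `instantiate_from_config` above implement the config-driven construction pattern used throughout latent-diffusion codebases: a config dict names a class via a dotted `target` path plus `params` kwargs. A minimal sketch under the conventional key names (the repository's version also handles a module-reload flag, which this omits):

```python
import importlib

def get_obj_from_str(string):
    """Resolve a dotted path like 'package.module.ClassName' to the object."""
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    """Build config['target'] with config.get('params', {}) as kwargs."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))
```

This is how the YAML files under `ToonCrafter/configs/` are consumed: each `target:` entry names a class in `lvdm` or `ldm`, and nested configs are instantiated recursively by the callers listed above.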
FILE: ToonCrafter/lvdm/basics.py
function disabled_train (line 14) | def disabled_train(self, mode=True):
function zero_module (line 20) | def zero_module(module):
function scale_module (line 29) | def scale_module(module, scale):
function conv_nd (line 38) | def conv_nd(dims, *args, **kwargs):
function linear (line 51) | def linear(*args, **kwargs):
function avg_pool_nd (line 58) | def avg_pool_nd(dims, *args, **kwargs):
function nonlinearity (line 71) | def nonlinearity(type='silu'):
class GroupNormSpecific (line 78) | class GroupNormSpecific(nn.GroupNorm):
method forward (line 79) | def forward(self, x):
function normalization (line 83) | def normalization(channels, num_groups=32):
class HybridConditioner (line 92) | class HybridConditioner(nn.Module):
method __init__ (line 94) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 99) | def forward(self, c_concat, c_crossattn):
FILE: ToonCrafter/lvdm/common.py
function gather_data (line 8) | def gather_data(data, return_np=True):
function autocast (line 16) | def autocast(f):
function extract_into_tensor (line 25) | def extract_into_tensor(a, t, x_shape):
function noise_like (line 31) | def noise_like(shape, device, repeat=False):
function default (line 37) | def default(val, d):
function exists (line 42) | def exists(val):
function identity (line 45) | def identity(*args, **kwargs):
function uniq (line 48) | def uniq(arr):
function mean_flat (line 51) | def mean_flat(tensor):
function ismap (line 57) | def ismap(x):
function isimage (line 62) | def isimage(x):
function max_neg_value (line 67) | def max_neg_value(t):
function shape_to_str (line 70) | def shape_to_str(x):
function init_ (line 74) | def init_(tensor):
function checkpoint (line 81) | def checkpoint(func, inputs, params, flag):
FILE: ToonCrafter/lvdm/data/base.py
class Txt2ImgIterableBaseDataset (line 5) | class Txt2ImgIterableBaseDataset(IterableDataset):
method __init__ (line 9) | def __init__(self, num_records=0, valid_ids=None, size=256):
method __len__ (line 18) | def __len__(self):
method __iter__ (line 22) | def __iter__(self):
FILE: ToonCrafter/lvdm/data/webvid.py
class WebVid (line 13) | class WebVid(Dataset):
method __init__ (line 25) | def __init__(self,
method _load_metadata (line 72) | def _load_metadata(self):
method _get_video_path (line 83) | def _get_video_path(self, sample):
method __getitem__ (line 88) | def __getitem__(self, index):
method __len__ (line 170) | def __len__(self):
FILE: ToonCrafter/lvdm/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self, noise=None):
method kl (line 42) | def kl(self, other=None):
method nll (line 56) | def nll(self, sample, dims=[1,2,3]):
method mode (line 64) | def mode(self):
function normal_kl (line 68) | def normal_kl(mean1, logvar1, mean2, logvar2):
FILE: ToonCrafter/lvdm/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 25) | def forward(self,model):
method copy_to (line 46) | def copy_to(self, model):
method store (line 55) | def store(self, parameters):
method restore (line 64) | def restore(self, parameters):
FILE: ToonCrafter/lvdm/models/autoencoder.py
class AutoencoderKL (line 15) | class AutoencoderKL(pl.LightningModule):
method __init__ (line 16) | def __init__(self,
method init_test (line 58) | def init_test(self,):
method init_from_ckpt (line 87) | def init_from_ckpt(self, path, ignore_keys=list()):
method encode (line 104) | def encode(self, x, return_hidden_states=False, **kwargs):
method decode (line 116) | def decode(self, z, **kwargs):
method forward (line 122) | def forward(self, input, sample_posterior=True, **additional_decode_kw...
method _forward (line 127) | def _forward(self, input, sample_posterior=True, **additional_decode_k...
method get_input (line 137) | def get_input(self, batch, k):
method training_step (line 147) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 168) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 182) | def configure_optimizers(self):
method get_last_layer (line 193) | def get_last_layer(self):
method log_images (line 197) | def log_images(self, batch, only_inputs=False, **kwargs):
method to_rgb (line 213) | def to_rgb(self, x):
class IdentityFirstStage (line 222) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 223) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 227) | def encode(self, x, *args, **kwargs):
method decode (line 230) | def decode(self, x, *args, **kwargs):
method quantize (line 233) | def quantize(self, x, *args, **kwargs):
method forward (line 238) | def forward(self, x, *args, **kwargs):
class AutoencoderKL_Dualref (line 245) | class AutoencoderKL_Dualref(AutoencoderKL):
method __init__ (line 246) | def __init__(self,
method _forward (line 266) | def _forward(self, input, sample_posterior=True, **additional_decode_k...
FILE: ToonCrafter/lvdm/models/autoencoder_dualref.py
function nonlinearity (line 24) | def nonlinearity(x):
function Normalize (line 29) | def Normalize(in_channels, num_groups=32):
class ResnetBlock (line 35) | class ResnetBlock(nn.Module):
method __init__ (line 36) | def __init__(
method forward (line 72) | def forward(self, x, temb):
class LinAttnBlock (line 95) | class LinAttnBlock(LinearAttention):
method __init__ (line 98) | def __init__(self, in_channels):
class AttnBlock (line 102) | class AttnBlock(nn.Module):
method __init__ (line 103) | def __init__(self, in_channels):
method attention (line 121) | def attention(self, h_: torch.Tensor) -> torch.Tensor:
method forward (line 138) | def forward(self, x, **kwargs):
class MemoryEfficientAttnBlock (line 145) | class MemoryEfficientAttnBlock(nn.Module):
method __init__ (line 153) | def __init__(self, in_channels):
method attention (line 172) | def attention(self, h_: torch.Tensor) -> torch.Tensor:
method forward (line 202) | def forward(self, x, **kwargs):
class CrossAttentionWrapper (line 209) | class CrossAttentionWrapper(CrossAttention):
method forward (line 210) | def forward(self, x, context=None, mask=None, **unused_kwargs):
class MemoryEfficientCrossAttentionWrapper (line 218) | class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
method forward (line 219) | def forward(self, x, context=None, mask=None, **unused_kwargs):
function make_attn (line 227) | def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
class CrossAttentionWrapperFusion (line 274) | class CrossAttentionWrapperFusion(CrossAttention):
method __init__ (line 275) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 282) | def forward(self, x, context=None, mask=None):
method _forward (line 288) | def _forward(
class MemoryEfficientCrossAttentionWrapperFusion (line 345) | class MemoryEfficientCrossAttentionWrapperFusion(MemoryEfficientCrossAtt...
method __init__ (line 347) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 353) | def forward(self, x, context=None, mask=None):
method _forward (line 359) | def _forward(
class Combiner (line 431) | class Combiner(nn.Module):
method __init__ (line 432) | def __init__(self, ch) -> None:
method forward (line 439) | def forward(self, x, context):
method _forward (line 445) | def _forward(self, x, context):
class Decoder (line 459) | class Decoder(nn.Module):
method __init__ (line 460) | def __init__(
method _make_attn (line 568) | def _make_attn(self) -> Callable:
method _make_resblock (line 571) | def _make_resblock(self) -> Callable:
method _make_conv (line 574) | def _make_conv(self) -> Callable:
method get_last_layer (line 577) | def get_last_layer(self, **kwargs):
method forward (line 580) | def forward(self, z, ref_context=None, **kwargs):
class TimestepBlock (line 636) | class TimestepBlock(nn.Module):
method forward (line 642) | def forward(self, x: torch.Tensor, emb: torch.Tensor):
class ResBlock (line 648) | class ResBlock(TimestepBlock):
method __init__ (line 664) | def __init__(
method forward (line 754) | def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
method _forward (line 766) | def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
class VideoTransformerBlock (line 800) | class VideoTransformerBlock(nn.Module):
method __init__ (line 806) | def __init__(
method forward (line 888) | def forward(
method _forward (line 896) | def _forward(self, x, context=None, timesteps=None):
method get_last_layer (line 929) | def get_last_layer(self):
function partialclass (line 939) | def partialclass(cls, *args, **kwargs):
class VideoResBlock (line 947) | class VideoResBlock(ResnetBlock):
method __init__ (line 948) | def __init__(
method get_alpha (line 985) | def get_alpha(self, bs):
method forward (line 993) | def forward(self, x, temb, skip_video=False, timesteps=None):
class AE3DConv (line 1015) | class AE3DConv(torch.nn.Conv2d):
method __init__ (line 1016) | def __init__(self, in_channels, out_channels, video_kernel_size=3, *ar...
method forward (line 1030) | def forward(self, input, timesteps, skip_video=False):
class VideoBlock (line 1039) | class VideoBlock(AttnBlock):
method __init__ (line 1040) | def __init__(
method forward (line 1071) | def forward(self, x, timesteps, skip_video=False):
method get_alpha (line 1098) | def get_alpha(
class MemoryEfficientVideoBlock (line 1109) | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
method __init__ (line 1110) | def __init__(
method forward (line 1141) | def forward(self, x, timesteps, skip_time_block=False):
method get_alpha (line 1168) | def get_alpha(
function make_time_attn (line 1179) | def make_time_attn(
class Conv2DWrapper (line 1217) | class Conv2DWrapper(torch.nn.Conv2d):
method forward (line 1218) | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
class VideoDecoder (line 1222) | class VideoDecoder(Decoder):
method __init__ (line 1225) | def __init__(
method get_last_layer (line 1243) | def get_last_layer(self, skip_time_mix=False, **kwargs):
method _make_attn (line 1253) | def _make_attn(self) -> Callable:
method _make_conv (line 1263) | def _make_conv(self) -> Callable:
method _make_resblock (line 1269) | def _make_resblock(self) -> Callable:
FILE: ToonCrafter/lvdm/models/ddpm3d.py
class DDPM (line 42) | class DDPM(pl.LightningModule):
method __init__ (line 44) | def __init__(self,
method register_schedule (line 125) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 191) | def ema_scope(self, context=None):
method init_from_ckpt (line 205) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 223) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 235) | def predict_start_from_noise(self, x_t, t, noise):
method predict_start_from_z_and_v (line 241) | def predict_start_from_z_and_v(self, x_t, t, v):
method predict_eps_from_z_and_v (line 249) | def predict_eps_from_z_and_v(self, x_t, t, v):
method q_posterior (line 255) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 264) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 277) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 286) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 301) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 307) | def q_sample(self, x_start, t, noise=None):
method get_v (line 312) | def get_v(self, x, noise, t):
method get_loss (line 318) | def get_loss(self, pred, target, mean=True):
method p_losses (line 333) | def p_losses(self, x_start, t, noise=None):
method forward (line 364) | def forward(self, x, *args, **kwargs):
method get_input (line 370) | def get_input(self, batch, k):
method shared_step (line 380) | def shared_step(self, batch):
method training_step (line 385) | def training_step(self, batch, batch_idx):
method validation_step (line 401) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 409) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 413) | def _get_rows_from_list(self, samples):
method log_images (line 421) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 458) | def configure_optimizers(self):
class LatentDiffusion (line 466) | class LatentDiffusion(DDPM):
method __init__ (line 468) | def __init__(self,
method make_cond_schedule (line 550) | def make_cond_schedule(self, ):
method on_train_batch_start (line 557) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
method register_schedule (line 574) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method instantiate_first_stage (line 582) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 589) | def instantiate_cond_stage(self, config):
method get_learned_conditioning (line 600) | def get_learned_conditioning(self, c):
method get_first_stage_encoding (line 613) | def get_first_stage_encoding(self, encoder_posterior, noise=None):
method encode_first_stage (line 623) | def encode_first_stage(self, x):
method decode_core (line 648) | def decode_core(self, z, **kwargs):
method decode_first_stage (line 692) | def decode_first_stage(self, z, **kwargs):
method differentiable_decode_first_stage (line 696) | def differentiable_decode_first_stage(self, z, **kwargs):
method get_batch_input (line 700) | def get_batch_input(self, batch, random_uncond, return_first_stage_out...
method forward (line 733) | def forward(self, x, c, **kwargs):
method shared_step (line 739) | def shared_step(self, batch, random_uncond, **kwargs):
method apply_model (line 745) | def apply_model(self, x_noisy, t, cond, **kwargs):
method p_losses (line 762) | def p_losses(self, x_start, cond, t, noise=None, **kwargs):
method training_step (line 808) | def training_step(self, batch, batch_idx):
method _get_denoise_row_from_list (line 822) | def _get_denoise_row_from_list(self, samples, desc=''):
method log_images (line 847) | def log_images(self, batch, sample=True, ddim_steps=200, ddim_eta=1., ...
method p_mean_variance (line 902) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_x0=Fals...
method p_sample (line 928) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, r...
method p_sample_loop (line 950) | def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=N...
method sample (line 997) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1014) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
method configure_schedulers (line 1025) | def configure_schedulers(self, optimizer):
class LatentVisualDiffusion (line 1051) | class LatentVisualDiffusion(LatentDiffusion):
method __init__ (line 1052) | def __init__(self, img_cond_stage_config, image_proj_stage_config, fre...
method _init_img_ctx_projector (line 1058) | def _init_img_ctx_projector(self, config, trainable):
method _init_embedder (line 1066) | def _init_embedder(self, config, freeze=True):
method shared_step (line 1074) | def shared_step(self, batch, random_uncond, **kwargs):
method get_batch_input (line 1080) | def get_batch_input(self, batch, random_uncond, return_first_stage_out...
method log_images (line 1147) | def log_images(self, batch, sample=True, ddim_steps=50, ddim_eta=1., p...
method configure_optimizers (line 1218) | def configure_optimizers(self):
class DiffusionWrapper (line 1253) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1254) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1259) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,
FILE: ToonCrafter/lvdm/models/samplers/ddim.py
class DDIMSampler (line 10) | class DDIMSampler(object):
method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 18) | def register_buffer(self, name, attr):
method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 60) | def sample(self,
method ddim_sampling (line 135) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 206) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method decode (line 282) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
method stochastic_encode (line 305) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
FILE: ToonCrafter/lvdm/models/samplers/ddim_multiplecond.py
class DDIMSampler (line 10) | class DDIMSampler(object):
method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 18) | def register_buffer(self, name, attr):
method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 60) | def sample(self,
method ddim_sampling (line 138) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 211) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method decode (line 288) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
method stochastic_encode (line 310) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
FILE: ToonCrafter/lvdm/models/utils_diffusion.py
function timestep_embedding (line 8) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function make_beta_schedule (line 31) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 56) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 79) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 94) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function rescale_zero_terminal_snr (line 112) | def rescale_zero_terminal_snr(betas):
function rescale_noise_cfg (line 147) | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
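The `timestep_embedding` helper listed above follows the standard sinusoidal scheme used throughout DDPM-derived codebases: half the channels carry cosines, half carry sines, over log-spaced frequencies up to `max_period`. A minimal NumPy sketch (the repository's version operates on torch tensors and also supports `repeat_only`, which is omitted here):

```python
import math
import numpy as np

def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal embeddings for diffusion timesteps.

    Returns an array of shape (len(timesteps), dim): the first half of the
    channels are cos terms, the second half sin terms, over frequencies
    exp(-log(max_period) * k / half) for k = 0 .. half-1.
    """
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # zero-pad one channel when dim is odd
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb

emb = timestep_embedding([0, 10, 999], 128)
print(emb.shape)  # (3, 128)
```

At t = 0 every cosine channel is 1 and every sine channel is 0, which is a quick sanity check that the cos/sin halves are ordered as intended.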
FILE: ToonCrafter/lvdm/modules/attention.py
class RelativePosition (line 20) | class RelativePosition(nn.Module):
method __init__ (line 23) | def __init__(self, num_units, max_relative_position):
method forward (line 30) | def forward(self, length_q, length_k):
class CrossAttention (line 42) | class CrossAttention(nn.Module):
method __init__ (line 44) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 81) | def forward(self, x, context=None, mask=None):
method efficient_forward (line 146) | def efficient_forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 212) | class BasicTransformerBlock(nn.Module):
method __init__ (line 214) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 231) | def forward(self, x, context=None, mask=None, **kwargs):
method _forward (line 242) | def _forward(self, x, context=None, mask=None):
class SpatialTransformer (line 249) | class SpatialTransformer(nn.Module):
method __init__ (line 259) | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., ...
method forward (line 294) | def forward(self, x, context=None, **kwargs):
class TemporalTransformer (line 313) | class TemporalTransformer(nn.Module):
method __init__ (line 320) | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., ...
method forward (line 365) | def forward(self, x, context=None):
class GEGLU (line 415) | class GEGLU(nn.Module):
method __init__ (line 416) | def __init__(self, dim_in, dim_out):
method forward (line 420) | def forward(self, x):
class FeedForward (line 425) | class FeedForward(nn.Module):
method __init__ (line 426) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 441) | def forward(self, x):
class LinearAttention (line 445) | class LinearAttention(nn.Module):
method __init__ (line 446) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 453) | def forward(self, x):
class SpatialSelfAttention (line 464) | class SpatialSelfAttention(nn.Module):
method __init__ (line 465) | def __init__(self, in_channels):
method forward (line 491) | def forward(self, x):
FILE: ToonCrafter/lvdm/modules/attention_svd.py
function exists (line 61) | def exists(val):
function uniq (line 65) | def uniq(arr):
function default (line 69) | def default(val, d):
function max_neg_value (line 75) | def max_neg_value(t):
function init_ (line 79) | def init_(tensor):
class GEGLU (line 87) | class GEGLU(nn.Module):
method __init__ (line 88) | def __init__(self, dim_in, dim_out):
method forward (line 92) | def forward(self, x):
class FeedForward (line 97) | class FeedForward(nn.Module):
method __init__ (line 98) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
method forward (line 112) | def forward(self, x):
function zero_module (line 116) | def zero_module(module):
function Normalize (line 125) | def Normalize(in_channels):
class LinearAttention (line 131) | class LinearAttention(nn.Module):
method __init__ (line 132) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 139) | def forward(self, x):
class SelfAttention (line 154) | class SelfAttention(nn.Module):
method __init__ (line 157) | def __init__(
method forward (line 182) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class SpatialSelfAttention (line 211) | class SpatialSelfAttention(nn.Module):
method __init__ (line 212) | def __init__(self, in_channels):
method forward (line 230) | def forward(self, x):
class CrossAttention (line 256) | class CrossAttention(nn.Module):
method __init__ (line 257) | def __init__(
method forward (line 282) | def forward(
class MemoryEfficientCrossAttention (line 348) | class MemoryEfficientCrossAttention(nn.Module):
method __init__ (line 350) | def __init__(
method forward (line 374) | def forward(
class BasicTransformerBlock (line 457) | class BasicTransformerBlock(nn.Module):
method __init__ (line 463) | def __init__(
method forward (line 528) | def forward(
method _forward (line 552) | def _forward(
class BasicTransformerSingleLayerBlock (line 576) | class BasicTransformerSingleLayerBlock(nn.Module):
method __init__ (line 583) | def __init__(
method forward (line 609) | def forward(self, x, context=None):
method _forward (line 614) | def _forward(self, x, context=None):
class SpatialTransformer (line 620) | class SpatialTransformer(nn.Module):
method __init__ (line 630) | def __init__(
method forward (line 703) | def forward(self, x, context=None):
class SimpleTransformer (line 727) | class SimpleTransformer(nn.Module):
method __init__ (line 728) | def __init__(
method forward (line 753) | def forward(
FILE: ToonCrafter/lvdm/modules/encoders/condition.py
class AbstractEncoder (line 12) | class AbstractEncoder(nn.Module):
method __init__ (line 13) | def __init__(self):
method encode (line 16) | def encode(self, *args, **kwargs):
class IdentityEncoder (line 20) | class IdentityEncoder(AbstractEncoder):
method encode (line 21) | def encode(self, x):
class ClassEmbedder (line 25) | class ClassEmbedder(nn.Module):
method __init__ (line 26) | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
method forward (line 33) | def forward(self, batch, key=None, disable_dropout=False):
method get_unconditional_conditioning (line 45) | def get_unconditional_conditioning(self, bs, device="cuda"):
function disabled_train (line 52) | def disabled_train(self, mode=True):
function get_available_devices (line 58) | def get_available_devices():
function get_device (line 68) | def get_device(device):
class FrozenT5Embedder (line 75) | class FrozenT5Embedder(AbstractEncoder):
method __init__ (line 78) | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_...
method freeze (line 88) | def freeze(self):
method forward (line 94) | def forward(self, text):
method encode (line 103) | def encode(self, text):
class FrozenCLIPEmbedder (line 107) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 115) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 131) | def freeze(self):
method forward (line 137) | def forward(self, text):
method encode (line 150) | def encode(self, text):
class ClipImageEmbedder (line 154) | class ClipImageEmbedder(nn.Module):
method __init__ (line 155) | def __init__(
method preprocess (line 173) | def preprocess(self, x):
method forward (line 183) | def forward(self, x, no_dropout=False):
class FrozenOpenCLIPEmbedder (line 192) | class FrozenOpenCLIPEmbedder(AbstractEncoder):
method __init__ (line 202) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
method freeze (line 223) | def freeze(self):
method forward (line 228) | def forward(self, text):
method encode_with_transformer (line 233) | def encode_with_transformer(self, text):
method text_transformer_forward (line 242) | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
method encode (line 252) | def encode(self, text):
class FrozenOpenCLIPImageEmbedder (line 256) | class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
method __init__ (line 261) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
method preprocess (line 285) | def preprocess(self, x):
method freeze (line 295) | def freeze(self):
method forward (line 301) | def forward(self, image, no_dropout=False):
method encode_with_vision_transformer (line 307) | def encode_with_vision_transformer(self, img):
method encode (line 312) | def encode(self, text):
class FrozenOpenCLIPImageEmbedderV2 (line 316) | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
method __init__ (line 321) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
method preprocess (line 343) | def preprocess(self, x):
method freeze (line 353) | def freeze(self):
method forward (line 358) | def forward(self, image, no_dropout=False):
method encode_with_vision_transformer (line 363) | def encode_with_vision_transformer(self, x):
class FrozenCLIPT5Encoder (line 396) | class FrozenCLIPT5Encoder(AbstractEncoder):
method __init__ (line 397) | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_ve...
method encode (line 405) | def encode(self, text):
method forward (line 408) | def forward(self, text):
FILE: ToonCrafter/lvdm/modules/encoders/resampler.py
class ImageProjModel (line 9) | class ImageProjModel(nn.Module):
method __init__ (line 11) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
method forward (line 18) | def forward(self, image_embeds):
function FeedForward (line 27) | def FeedForward(dim, mult=4):
function reshape_tensor (line 37) | def reshape_tensor(x, heads):
class PerceiverAttention (line 48) | class PerceiverAttention(nn.Module):
method __init__ (line 49) | def __init__(self, *, dim, dim_head=64, heads=8):
method forward (line 64) | def forward(self, x, latents):
class Resampler (line 96) | class Resampler(nn.Module):
method __init__ (line 97) | def __init__(
method forward (line 134) | def forward(self, x):
FILE: ToonCrafter/lvdm/modules/networks/ae_modules.py
function nonlinearity (line 13) | def nonlinearity(x):
function Normalize (line 18) | def Normalize(in_channels, num_groups=32):
class LinAttnBlock (line 22) | class LinAttnBlock(LinearAttention):
method __init__ (line 25) | def __init__(self, in_channels):
class AttnBlock (line 29) | class AttnBlock(nn.Module):
method __init__ (line 30) | def __init__(self, in_channels):
method forward (line 56) | def forward(self, x):
function make_attn (line 84) | def make_attn(in_channels, attn_type="vanilla"):
class Downsample (line 95) | class Downsample(nn.Module):
method __init__ (line 96) | def __init__(self, in_channels, with_conv):
method forward (line 108) | def forward(self, x):
class Upsample (line 118) | class Upsample(nn.Module):
method __init__ (line 119) | def __init__(self, in_channels, with_conv):
method forward (line 130) | def forward(self, x):
function get_timestep_embedding (line 137) | def get_timestep_embedding(timesteps, embedding_dim):
class ResnetBlock (line 158) | class ResnetBlock(nn.Module):
method __init__ (line 159) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 197) | def forward(self, x, temb):
class Model (line 220) | class Model(nn.Module):
method __init__ (line 221) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
method forward (line 321) | def forward(self, x, t=None, context=None):
method get_last_layer (line 369) | def get_last_layer(self):
class Encoder (line 373) | class Encoder(nn.Module):
method __init__ (line 374) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
method forward (line 440) | def forward(self, x, return_hidden_states=False):
class Decoder (line 486) | class Decoder(nn.Module):
method __init__ (line 487) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
method forward (line 560) | def forward(self, z):
class SimpleDecoder (line 602) | class SimpleDecoder(nn.Module):
method __init__ (line 603) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 625) | def forward(self, x):
class UpsampleDecoder (line 638) | class UpsampleDecoder(nn.Module):
method __init__ (line 639) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 672) | def forward(self, x):
class LatentRescaler (line 686) | class LatentRescaler(nn.Module):
method __init__ (line 687) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 711) | def forward(self, x):
class MergedRescaleEncoder (line 723) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 724) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 736) | def forward(self, x):
class MergedRescaleDecoder (line 742) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 743) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 753) | def forward(self, x):
class Upsampler (line 759) | class Upsampler(nn.Module):
method __init__ (line 760) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 772) | def forward(self, x):
class Resize (line 778) | class Resize(nn.Module):
method __init__ (line 779) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 794) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 802) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 804) | def __init__(self, ch_mult: list, in_channels,
method instantiate_pretrained (line 838) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 846) | def encode_with_pretrained(self, x):
method forward (line 852) | def forward(self, x):
FILE: ToonCrafter/lvdm/modules/networks/openaimodel3d.py
class TimestepBlock (line 19) | class TimestepBlock(nn.Module):
method forward (line 24) | def forward(self, x, emb):
class TimestepEmbedSequential (line 30) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 36) | def forward(self, x, emb, context=None, batch_size=None):
class Downsample (line 51) | class Downsample(nn.Module):
method __init__ (line 60) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 75) | def forward(self, x):
class Upsample (line 80) | class Upsample(nn.Module):
method __init__ (line 89) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 98) | def forward(self, x):
class ResBlock (line 109) | class ResBlock(TimestepBlock):
method __init__ (line 126) | def __init__(
method forward (line 197) | def forward(self, x, emb, batch_size=None):
method _forward (line 210) | def _forward(self, x, emb, batch_size=None):
class TemporalConvBlock (line 239) | class TemporalConvBlock(nn.Module):
method __init__ (line 243) | def __init__(self, in_channels, out_channels=None, dropout=0.0, spatia...
method forward (line 272) | def forward(self, x):
class UNetModel (line 281) | class UNetModel(nn.Module):
method __init__ (line 311) | def __init__(self,
method forward (line 548) | def forward(self, x, timesteps, context=None, features_adapter=None, f...
FILE: ToonCrafter/lvdm/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 24) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 25) | def __init__(self, dim, max_seq_len):
method init_ (line 30) | def init_(self):
method forward (line 33) | def forward(self, x):
class FixedPositionalEmbedding (line 38) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 39) | def __init__(self, dim):
method forward (line 44) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 53) | def exists(val):
function default (line 57) | def default(val, d):
function always (line 63) | def always(val):
function not_equals (line 69) | def not_equals(val):
function equals (line 75) | def equals(val):
function max_neg_value (line 81) | def max_neg_value(tensor):
function pick_and_pop (line 87) | def pick_and_pop(keys, d):
function group_dict_by_key (line 92) | def group_dict_by_key(cond, d):
function string_begins_with (line 101) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 105) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 109) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 116) | class Scale(nn.Module):
method __init__ (line 117) | def __init__(self, value, fn):
method forward (line 122) | def forward(self, x, **kwargs):
class Rezero (line 127) | class Rezero(nn.Module):
method __init__ (line 128) | def __init__(self, fn):
method forward (line 133) | def forward(self, x, **kwargs):
class ScaleNorm (line 138) | class ScaleNorm(nn.Module):
method __init__ (line 139) | def __init__(self, dim, eps=1e-5):
method forward (line 145) | def forward(self, x):
class RMSNorm (line 150) | class RMSNorm(nn.Module):
method __init__ (line 151) | def __init__(self, dim, eps=1e-8):
method forward (line 157) | def forward(self, x):
class Residual (line 162) | class Residual(nn.Module):
method forward (line 163) | def forward(self, x, residual):
class GRUGating (line 167) | class GRUGating(nn.Module):
method __init__ (line 168) | def __init__(self, dim):
method forward (line 172) | def forward(self, x, residual):
class GEGLU (line 183) | class GEGLU(nn.Module):
method __init__ (line 184) | def __init__(self, dim_in, dim_out):
method forward (line 188) | def forward(self, x):
class FeedForward (line 193) | class FeedForward(nn.Module):
method __init__ (line 194) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 209) | def forward(self, x):
class Attention (line 214) | class Attention(nn.Module):
method __init__ (line 215) | def __init__(
method forward (line 267) | def forward(
class AttentionLayers (line 369) | class AttentionLayers(nn.Module):
method __init__ (line 370) | def __init__(
method forward (line 480) | def forward(
class Encoder (line 540) | class Encoder(AttentionLayers):
method __init__ (line 541) | def __init__(self, **kwargs):
class TransformerWrapper (line 547) | class TransformerWrapper(nn.Module):
method __init__ (line 548) | def __init__(
method init_ (line 594) | def init_(self):
method forward (line 597) | def forward(
FILE: ToonCrafter/main/callbacks.py
class ImageLogger (line 15) | class ImageLogger(Callback):
method __init__ (line 16) | def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=...
method log_to_tensorboard (line 31) | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, s...
method log_batch_imgs (line 58) | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
method on_train_batch_end (line 90) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 94) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
class CUDACallback (line 104) | class CUDACallback(Callback):
method on_train_epoch_start (line 106) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 117) | def on_train_epoch_end(self, trainer, pl_module):
FILE: ToonCrafter/main/trainer.py
function get_parser (line 14) | def get_parser(**parser_kwargs):
function get_nondefault_trainer_args (line 33) | def get_nondefault_trainer_args(args):
function melk (line 129) | def melk(*args, **kwargs):
function divein (line 136) | def divein(*args, **kwargs):
FILE: ToonCrafter/main/utils_data.py
function worker_init_fn (line 15) | def worker_init_fn(_):
class WrappedDataset (line 31) | class WrappedDataset(Dataset):
method __init__ (line 34) | def __init__(self, dataset):
method __len__ (line 37) | def __len__(self):
method __getitem__ (line 40) | def __getitem__(self, idx):
class DataModuleFromConfig (line 44) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 45) | def __init__(self, batch_size, train=None, validation=None, test=None,...
method prepare_data (line 72) | def prepare_data(self):
method setup (line 75) | def setup(self, stage=None):
method _train_dataloader (line 81) | def _train_dataloader(self):
method _val_dataloader (line 93) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 106) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 128) | def _predict_dataloader(self, shuffle=False):
FILE: ToonCrafter/main/utils_train.py
function init_workspace (line 9) | def init_workspace(name, logdir, model_config, lightning_config, rank=0):
function check_config_attribute (line 28) | def check_config_attribute(config, name):
function get_trainer_callbacks (line 35) | def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, log...
function get_trainer_logger (line 99) | def get_trainer_logger(lightning_config, logdir, on_debug):
function get_trainer_strategy (line 125) | def get_trainer_strategy(lightning_config):
function load_checkpoints (line 138) | def load_checkpoints(model, model_cfg):
function set_logger (line 162) | def set_logger(logfile, name='mainlogger'):
FILE: ToonCrafter/scripts/evaluation/ddp_wrapper.py
function setup_dist (line 8) | def setup_dist(local_rank):
function get_dist_info (line 15) | def get_dist_info():
FILE: ToonCrafter/scripts/evaluation/funcs.py
function batch_ddim_sampling (line 17) | def batch_ddim_sampling(model: LatentDiffusion, cond, noise_shape, n_sam...
function get_filelist (line 99) | def get_filelist(data_dir, ext='*'):
function get_dirlist (line 105) | def get_dirlist(path):
function load_model_checkpoint (line 117) | def load_model_checkpoint(model, ckpt):
function load_prompts (line 155) | def load_prompts(prompt_file):
function load_video_batch (line 166) | def load_video_batch(filepath_list, frame_stride, video_size=(256, 256),...
function load_image_batch (line 208) | def load_image_batch(filepath_list, image_size=(256, 256)):
function save_videos (line 232) | def save_videos(batch_tensors, savedir, filenames, fps=10):
function get_latent_z (line 247) | def get_latent_z(model, videos):
FILE: ToonCrafter/scripts/evaluation/inference.py
function get_filelist (line 19) | def get_filelist(data_dir, postfixes):
function load_model_checkpoint (line 27) | def load_model_checkpoint(model, ckpt):
function load_prompts (line 54) | def load_prompts(prompt_file):
function load_data_prompts (line 64) | def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, i...
function save_results (line 109) | def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
function save_results_seperate (line 135) | def save_results_seperate(prompt, samples, filename, fakedir, fps=10, lo...
function get_latent_z (line 157) | def get_latent_z(model, videos):
function get_latent_z_with_hidden_states (line 164) | def get_latent_z_with_hidden_states(model, videos):
function image_guided_synthesis (line 180) | def image_guided_synthesis(model, prompts, videos, noise_shape, n_sample...
function run_inference (line 277) | def run_inference(args, gpu_num, gpu_no):
function get_parser (line 344) | def get_parser():
FILE: ToonCrafter/scripts/gradio/i2v_test.py
class Image2Video (line 13) | class Image2Video():
method __init__ (line 14) | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') ...
method get_image (line 37) | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, f...
method download_model (line 94) | def download_model(self):
FILE: ToonCrafter/scripts/gradio/i2v_test_application.py
class Image2Video (line 13) | class Image2Video():
method __init__ (line 14) | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') ...
method get_image (line 38) | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, f...
method download_model (line 117) | def download_model(self):
method get_latent_z_with_hidden_states (line 127) | def get_latent_z_with_hidden_states(self, model, videos):
FILE: ToonCrafter/utils/save_video.py
function frames_to_mp4 (line 14) | def frames_to_mp4(frame_dir,output_path,fps):
function tensor_to_mp4 (line 27) | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
function tensor2videogrids (line 44) | def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=Tr...
function log_local (line 62) | def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
function prepare_to_log (line 120) | def prepare_to_log(batch_logs, max_images=100000, clamp=True):
function fill_with_black_squares (line 140) | def fill_with_black_squares(video, desired_len: int) -> Tensor:
function load_num_videos (line 150) | def load_num_videos(data_path, num_videos):
function npz_to_video_grid (line 163) | def npz_to_video_grid(data_path, out_path, num_frames, fps, num_videos=N...
FILE: ToonCrafter/utils/utils.py
function count_params (line 9) | def count_params(model, verbose=False):
function check_istarget (line 16) | def check_istarget(name, para_list):
function instantiate_from_config (line 28) | def instantiate_from_config(config):
function get_obj_from_str (line 38) | def get_obj_from_str(string, reload=False):
function load_npz_from_dir (line 46) | def load_npz_from_dir(data_dir):
function load_npz_from_paths (line 52) | def load_npz_from_paths(data_paths):
function resize_numpy_image (line 58) | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edg...
function setup_dist (line 71) | def setup_dist(args):
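The `instantiate_from_config` / `get_obj_from_str` pair listed above (duplicated again in the top-level `__init__.py` below) is the config-driven factory pattern shared by Stable Diffusion-derived repositories: a config mapping carries a dotted `target` path plus optional `params`, and the object is built by importing the target and splatting the params. A self-contained sketch of that pattern, shown here with a stdlib class instead of a model config:

```python
import importlib

def get_obj_from_str(string, reload=False):
    # "pkg.module.ClassName" -> the class (or callable) object
    module, cls = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # config: mapping with a "target" dotted path and optional "params" kwargs
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# Example: build an object purely from a config dict
obj = instantiate_from_config({
    "target": "fractions.Fraction",
    "params": {"numerator": 3, "denominator": 4},
})
print(obj)  # 3/4
```

This is why the YAML configs in `configs/` (e.g. `target: lvdm.models.ddpm3d.LatentVisualDiffusion` with a `params:` block) can swap model classes without code changes.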
FILE: __init__.py
function instantiate_from_config (line 36) | def instantiate_from_config(config):
function get_obj_from_str (line 46) | def get_obj_from_str(string, reload=False):
function get_state_dict (line 54) | def get_state_dict(d):
function load_state_dict (line 58) | def load_state_dict(ckpt_path, location='cpu'):
function get_models (line 71) | def get_models(root: Path = ROOT.joinpath("checkpoints"), ignoreed: tupl...
class ToonCrafterNode (line 83) | class ToonCrafterNode:
method INPUT_TYPES (line 86) | def INPUT_TYPES(s):
method init (line 111) | def init(self, ckpt_name="", result_dir=ROOT.joinpath("tmp/"), gpu_num...
method optional_autocast (line 144) | def optional_autocast(device):
method get_image (line 152) | def get_image(self, image: torch.Tensor, ckpt_name, vram_opt_strategy,...
method save_videos (line 266) | def save_videos(self, batch_tensors, savedir, filenames, fps=10):
method download_model (line 281) | def download_model(self):
method get_latent_z_with_hidden_states (line 291) | def get_latent_z_with_hidden_states(self, model, videos):
class ToonCrafterWithSketch (line 308) | class ToonCrafterWithSketch(ToonCrafterNode):
method INPUT_TYPES (line 311) | def INPUT_TYPES(s):
method init (line 337) | def init(self, ckpt_name="", result_dir=ROOT.joinpath("tmp/"), gpu_num...
method get_image (line 380) | def get_image(self, image: torch.Tensor, ckpt_name, vram_opt_strategy,...
FILE: pre_run.py
function run (line 9) | def run():
Condensed preview — 98 files, each showing path, character count, and a content snippet (991K chars of structured content in total).
[
{
"path": ".github/workflows/publish.yml",
"chars": 581,
"preview": "name: Publish to Comfy registry\non:\n workflow_dispatch:\n push:\n branches:\n - main\n - master\n paths:\n "
},
{
"path": ".gitignore",
"chars": 135,
"preview": ".DS_Store\n*pyc\n.vscode\n__pycache__\n*.egg-info\n\ncheckpoints\nToonCrafter/checkpoints\nresults\nbackup\nLOG\n/models\nToonCrafte"
},
{
"path": "LICENSE",
"chars": 11331,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "ToonCrafter/.gitignore",
"chars": 77,
"preview": ".DS_Store\n*pyc\n.vscode\n__pycache__\n*.egg-info\n\ncheckpoints\nresults\nbackup\nLOG"
},
{
"path": "ToonCrafter/LICENSE",
"chars": 11331,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "ToonCrafter/README.md",
"chars": 5619,
"preview": "## ___***ToonCrafter: Generative Cartoon Interpolation***___\n<!-- {: width"
},
{
"path": "ToonCrafter/__init__.py",
"chars": 86,
"preview": "import sys\nfrom pathlib import Path\nsys.path.append(Path(__file__).parent.as_posix())\n"
},
{
"path": "ToonCrafter/cldm/cldm.py",
"chars": 21052,
"preview": "import einops\nimport torch\nimport torch as th\nimport torch.nn as nn\n\nfrom ToonCrafter.ldm.modules.diffusionmodules.util "
},
{
"path": "ToonCrafter/cldm/ddim_hacked.py",
"chars": 16446,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom ToonCrafter.ldm.modules.diffusionmodul"
},
{
"path": "ToonCrafter/cldm/hack.py",
"chars": 3579,
"preview": "import torch\nimport einops\n\nimport ldm.modules.encoders.modules\nimport ldm.modules.attention\n\nfrom transformers import l"
},
{
"path": "ToonCrafter/cldm/logger.py",
"chars": 3182,
"preview": "import os\n\nimport numpy as np\nimport torch\nimport torchvision\nfrom PIL import Image\nfrom pytorch_lightning.callbacks imp"
},
{
"path": "ToonCrafter/cldm/model.py",
"chars": 842,
"preview": "import os\nimport torch\n\nfrom omegaconf import OmegaConf\nfrom comfy.ldm.util import instantiate_from_config\n\n\ndef get_sta"
},
{
"path": "ToonCrafter/configs/cldm_v21.yaml",
"chars": 476,
"preview": "control_stage_config:\n target: ToonCrafter.cldm.cldm.ControlNet\n params:\n use_checkpoint: True\n image_size: 32 #"
},
{
"path": "ToonCrafter/configs/inference_512_v1.0.yaml",
"chars": 2519,
"preview": "model:\n target: lvdm.models.ddpm3d.LatentVisualDiffusion\n params:\n rescale_betas_zero_snr: True\n parameterizatio"
},
{
"path": "ToonCrafter/configs/training_1024_v1.0/config.yaml",
"chars": 4260,
"preview": "model:\n pretrained_checkpoint: checkpoints/dynamicrafter_1024_v1/model.ckpt\n base_learning_rate: 1.0e-05\n scale_lr: F"
},
{
"path": "ToonCrafter/configs/training_1024_v1.0/run.sh",
"chars": 1020,
"preview": "# NCCL configuration\n# export NCCL_DEBUG=INFO\n# export NCCL_IB_DISABLE=0\n# export NCCL_IB_GID_INDEX=3\n# export NCCL_NET_"
},
{
"path": "ToonCrafter/configs/training_512_v1.0/config.yaml",
"chars": 4257,
"preview": "model:\n pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt\n base_learning_rate: 1.0e-05\n scale_lr: Fa"
},
{
"path": "ToonCrafter/configs/training_512_v1.0/run.sh",
"chars": 1019,
"preview": "# NCCL configuration\n# export NCCL_DEBUG=INFO\n# export NCCL_IB_DISABLE=0\n# export NCCL_IB_GID_INDEX=3\n# export NCCL_NET_"
},
{
"path": "ToonCrafter/gradio_app.py",
"chars": 4212,
"preview": "import os\nimport argparse\nimport sys\nimport gradio as gr\nfrom scripts.gradio.i2v_test_application import Image2Video\nsys"
},
{
"path": "ToonCrafter/ldm/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/data/util.py",
"chars": 641,
"preview": "import torch\n\nfrom ToonCrafter.ldm.modules.midas.api import load_midas_transform\n\n\nclass AddMiDaS(object):\n def __ini"
},
{
"path": "ToonCrafter/ldm/models/autoencoder.py",
"chars": 8608,
"preview": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\n\nfrom "
},
{
"path": "ToonCrafter/ldm/models/diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/models/diffusion/ddim.py",
"chars": 17316,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom ToonCrafter.ldm.modules.diffusionmodul"
},
{
"path": "ToonCrafter/ldm/models/diffusion/ddpm.py",
"chars": 84719,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
},
{
"path": "ToonCrafter/ldm/models/diffusion/dpm_solver/__init__.py",
"chars": 37,
"preview": "from .sampler import DPMSolverSampler"
},
{
"path": "ToonCrafter/ldm/models/diffusion/dpm_solver/dpm_solver.py",
"chars": 65968,
"preview": "import torch\nimport torch.nn.functional as F\nimport math\nfrom tqdm import tqdm\n\n\nclass NoiseScheduleVP:\n def __init__"
},
{
"path": "ToonCrafter/ldm/models/diffusion/dpm_solver/sampler.py",
"chars": 2990,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\nimport torch\n\nfrom .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver\n\n\nMODEL_TYPES = {\n"
},
{
"path": "ToonCrafter/ldm/models/diffusion/plms.py",
"chars": 12951,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ToonCraf"
},
{
"path": "ToonCrafter/ldm/models/diffusion/sampling_util.py",
"chars": 753,
"preview": "import torch\nimport numpy as np\n\n\ndef append_dims(x, target_dims):\n \"\"\"Appends dimensions to the end of a tensor unti"
},
{
"path": "ToonCrafter/ldm/modules/attention.py",
"chars": 11818,
"preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
},
{
"path": "ToonCrafter/ldm/modules/diffusionmodules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/modules/diffusionmodules/model.py",
"chars": 34396,
"preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom ein"
},
{
"path": "ToonCrafter/ldm/modules/diffusionmodules/openaimodel.py",
"chars": 30413,
"preview": "from abc import abstractmethod\nimport math\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn."
},
{
"path": "ToonCrafter/ldm/modules/diffusionmodules/upscaling.py",
"chars": 3448,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom functools import partial\n\nfrom ToonCrafter.ldm.modules.diffus"
},
{
"path": "ToonCrafter/ldm/modules/diffusionmodules/util.py",
"chars": 9880,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "ToonCrafter/ldm/modules/distributions/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/modules/distributions/distributions.py",
"chars": 2970,
"preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n def sample(self):\n raise NotImplementedError()\n"
},
{
"path": "ToonCrafter/ldm/modules/ema.py",
"chars": 3110,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "ToonCrafter/ldm/modules/encoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/modules/encoders/modules.py",
"chars": 7623,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.utils.checkpoint import checkpoint\n\nfrom transformers import T5Tokenizer, "
},
{
"path": "ToonCrafter/ldm/modules/image_degradation/__init__.py",
"chars": 232,
"preview": "from ToonCrafter.ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ToonC"
},
{
"path": "ToonCrafter/ldm/modules/image_degradation/bsrgan.py",
"chars": 25198,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# ------------------------"
},
{
"path": "ToonCrafter/ldm/modules/image_degradation/bsrgan_light.py",
"chars": 22341,
"preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom sci"
},
{
"path": "ToonCrafter/ldm/modules/image_degradation/utils_image.py",
"chars": 29022,
"preview": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nf"
},
{
"path": "ToonCrafter/ldm/modules/midas/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/modules/midas/api.py",
"chars": 5386,
"preview": "# based on https://github.com/isl-org/MiDaS\n\nimport cv2\nimport torch\nimport torch.nn as nn\nfrom torchvision.transforms i"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/base_model.py",
"chars": 367,
"preview": "import torch\n\n\nclass BaseModel(torch.nn.Module):\n def load(self, path):\n \"\"\"Load model from file.\n\n Arg"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/blocks.py",
"chars": 9242,
"preview": "import torch\nimport torch.nn as nn\n\nfrom .vit import (\n _make_pretrained_vitb_rn50_384,\n _make_pretrained_vitl16_3"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/dpt_depth.py",
"chars": 3154,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .base_model import BaseModel\nfrom .blocks impor"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/midas_net.py",
"chars": 2709,
"preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/midas_net_custom.py",
"chars": 5207,
"preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/transforms.py",
"chars": 7869,
"preview": "import numpy as np\nimport cv2\nimport math\n\n\ndef apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):"
},
{
"path": "ToonCrafter/ldm/modules/midas/midas/vit.py",
"chars": 14625,
"preview": "import torch\nimport torch.nn as nn\nimport timm\nimport types\nimport math\nimport torch.nn.functional as F\n\n\nclass Slice(nn"
},
{
"path": "ToonCrafter/ldm/modules/midas/utils.py",
"chars": 4582,
"preview": "\"\"\"Utils for monoDepth.\"\"\"\nimport sys\nimport re\nimport numpy as np\nimport cv2\nimport torch\n\n\ndef read_pfm(path):\n \"\"\""
},
{
"path": "ToonCrafter/ldm/util.py",
"chars": 7227,
"preview": "import importlib\n\nimport torch\nfrom torch import optim\nimport numpy as np\n\nfrom inspect import isfunction\nfrom PIL impor"
},
{
"path": "ToonCrafter/lvdm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/lvdm/basics.py",
"chars": 2864,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "ToonCrafter/lvdm/common.py",
"chars": 2819,
"preview": "import math\nfrom inspect import isfunction\nimport torch\nfrom torch import nn\nimport torch.distributed as dist\n\n\ndef gath"
},
{
"path": "ToonCrafter/lvdm/data/base.py",
"chars": 655,
"preview": "from abc import abstractmethod\nfrom torch.utils.data import IterableDataset\n\n\nclass Txt2ImgIterableBaseDataset(IterableD"
},
{
"path": "ToonCrafter/lvdm/data/webvid.py",
"chars": 7926,
"preview": "import os\nimport random\nfrom tqdm import tqdm\nimport pandas as pd\nfrom decord import VideoReader, cpu\n\nimport torch\nfrom"
},
{
"path": "ToonCrafter/lvdm/distributions.py",
"chars": 3042,
"preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n def sample(self):\n raise NotImplementedError()\n"
},
{
"path": "ToonCrafter/lvdm/ema.py",
"chars": 2981,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "ToonCrafter/lvdm/models/autoencoder.py",
"chars": 11382,
"preview": "import os\nfrom contextlib import contextmanager\nimport torch\nimport numpy as np\nfrom einops import rearrange\nimport torc"
},
{
"path": "ToonCrafter/lvdm/models/autoencoder_dualref.py",
"chars": 42914,
"preview": "#### https://github.com/Stability-AI/generative-models\nfrom einops import rearrange, repeat\nimport logging\nfrom typing i"
},
{
"path": "ToonCrafter/lvdm/models/ddpm3d.py",
"chars": 58857,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_"
},
{
"path": "ToonCrafter/lvdm/models/samplers/ddim.py",
"chars": 16095,
"preview": "import numpy as np\nfrom tqdm import tqdm\nimport torch\nfrom lvdm.models.utils_diffusion import make_ddim_sampling_paramet"
},
{
"path": "ToonCrafter/lvdm/models/samplers/ddim_multiplecond.py",
"chars": 16351,
"preview": "import numpy as np\nfrom tqdm import tqdm\nimport torch\nfrom lvdm.models.utils_diffusion import make_ddim_sampling_paramet"
},
{
"path": "ToonCrafter/lvdm/models/utils_diffusion.py",
"chars": 6833,
"preview": "import math\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom einops import repeat\n\n\ndef timestep_emb"
},
{
"path": "ToonCrafter/lvdm/modules/attention.py",
"chars": 21384,
"preview": "import torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom func"
},
{
"path": "ToonCrafter/lvdm/modules/attention_svd.py",
"chars": 25167,
"preview": "import logging\nimport math\nfrom inspect import isfunction\nfrom typing import Any, Optional\n\nimport torch\nimport torch.nn"
},
{
"path": "ToonCrafter/lvdm/modules/encoders/condition.py",
"chars": 15332,
"preview": "import torch\nimport torch.nn as nn\nimport kornia\nimport open_clip\nimport os\nfrom torch.utils.checkpoint import checkpoin"
},
{
"path": "ToonCrafter/lvdm/modules/encoders/resampler.py",
"chars": 4961,
"preview": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://gith"
},
{
"path": "ToonCrafter/lvdm/modules/networks/ae_modules.py",
"chars": 34864,
"preview": "# pytorch_diffusion + derived encoder decoder\nimport math\n\nimport torch\nimport numpy as np\nimport torch.nn as nn\nfrom ei"
},
{
"path": "ToonCrafter/lvdm/modules/networks/openaimodel3d.py",
"chars": 26086,
"preview": "from functools import partial\nfrom abc import abstractmethod\nimport torch\nimport torch.nn as nn\nfrom einops import rearr"
},
{
"path": "ToonCrafter/lvdm/modules/x_transformer.py",
"chars": 20157,
"preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nfrom functools import partial\nf"
},
{
"path": "ToonCrafter/main/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/main/callbacks.py",
"chars": 6312,
"preview": "import os\nimport time\nimport logging\nmainlogger = logging.getLogger('mainlogger')\n\nimport torch\nimport torchvision\nimpor"
},
{
"path": "ToonCrafter/main/trainer.py",
"chars": 7494,
"preview": "import argparse, os, sys, datetime\nfrom omegaconf import OmegaConf\nfrom transformers import logging as transf_logging\nim"
},
{
"path": "ToonCrafter/main/utils_data.py",
"chars": 5603,
"preview": "from functools import partial\nimport numpy as np\n\nimport torch\nimport pytorch_lightning as pl\nfrom torch.utils.data impo"
},
{
"path": "ToonCrafter/main/utils_train.py",
"chars": 6963,
"preview": "import os, re\nfrom omegaconf import OmegaConf\nimport logging\nmainlogger = logging.getLogger('mainlogger')\n\nimport torch\n"
},
{
"path": "ToonCrafter/prompts/512_interp/prompts.txt",
"chars": 41,
"preview": "walking man\nan anime scene\nan anime scene"
},
{
"path": "ToonCrafter/requirements.txt",
"chars": 307,
"preview": "decord==0.6.0\neinops==0.3.0\nimageio==2.9.0\nnumpy==1.24.2\nomegaconf==2.1.1\nopencv_python\npandas==2.0.0\nPillow==9.5.0\npyto"
},
{
"path": "ToonCrafter/scripts/evaluation/ddp_wrapper.py",
"chars": 1535,
"preview": "import datetime\nimport argparse, importlib\nfrom pytorch_lightning import seed_everything\n\nimport torch\nimport torch.dist"
},
{
"path": "ToonCrafter/scripts/evaluation/funcs.py",
"chars": 10293,
"preview": "import os\nimport sys\nimport glob\nimport numpy as np\nfrom collections import OrderedDict\nfrom decord import VideoReader, "
},
{
"path": "ToonCrafter/scripts/evaluation/inference.py",
"chars": 18558,
"preview": "import argparse, os, sys, glob\nimport datetime, time\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom einops i"
},
{
"path": "ToonCrafter/scripts/gradio/i2v_test.py",
"chars": 5037,
"preview": "import os\nimport time\nfrom omegaconf import OmegaConf\nimport torch\nfrom scripts.evaluation.funcs import load_model_check"
},
{
"path": "ToonCrafter/scripts/gradio/i2v_test_application.py",
"chars": 6671,
"preview": "import os\nimport time\nfrom omegaconf import OmegaConf\nimport torch\nfrom scripts.evaluation.funcs import load_model_check"
},
{
"path": "ToonCrafter/scripts/run.sh",
"chars": 727,
"preview": "\nckpt=checkpoints/tooncrafter_512_interp_v1/model.ckpt\nconfig=configs/inference_512_v1.0.yaml\n\nprompt_dir=prompts/512_in"
},
{
"path": "ToonCrafter/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ToonCrafter/utils/save_video.py",
"chars": 8485,
"preview": "import os\nimport numpy as np\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom einops import rearrange\n\nimport torch\nimpo"
},
{
"path": "ToonCrafter/utils/utils.py",
"chars": 2179,
"preview": "import os\nimport importlib\nimport numpy as np\nimport cv2\nimport torch\nimport torch.distributed as dist\n\n\ndef count_param"
},
{
"path": "__init__.py",
"chars": 23263,
"preview": "import os\nimport sys\nimport torch\nimport time\nimport logging as logger\nimport importlib\n\nfrom functools import cache\nfro"
},
{
"path": "pre_run.py",
"chars": 1435,
"preview": "import sys\nimport argparse\nimport os\nfrom pathlib import Path\nsys.path.append(Path(__file__).parent.as_posix())\nfrom scr"
},
{
"path": "pyproject.toml",
"chars": 774,
"preview": "[project]\nname = \"comfyui-tooncrafter\"\ndescription = \"This project is used to enable [a/ToonCrafter](https://github.com/"
},
{
"path": "readme.md",
"chars": 2150,
"preview": "# Introduction\nThis project is used to enable [ToonCrafter](https://github.com/ToonCrafter/ToonCrafter) to be used in Co"
},
{
"path": "requirements.txt",
"chars": 178,
"preview": "imageio==2.9.0\nnumpy\nomegaconf==2.1.1\nopencv_python\npandas\npytorch_lightning==1.9.3\npyyaml\nsetuptools\nmoviepy\nav\nxformer"
}
]
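The checkpoint helpers in the symbol index (`get_state_dict`, `load_state_dict`) deal with the fact that pytorch_lightning checkpoints nest model weights under a `"state_dict"` key (alongside `"epoch"`, `"global_step"`, etc.), while plain `torch.save` dumps do not. A minimal sketch of the unwrapping step — in the repo, `load_state_dict(ckpt_path, location='cpu')` presumably feeds this the result of `torch.load(ckpt_path, map_location=location)`; the body below is an assumption, not repo code:

```python
# Hedged sketch of the get_state_dict helper listed in the symbol index.
# Lightning-style checkpoints look like {"state_dict": {...}, "epoch": ...};
# bare weight dumps are already the weight dict itself.

def get_state_dict(d: dict) -> dict:
    # Unwrap a Lightning-style checkpoint; pass plain dicts through.
    return d["state_dict"] if "state_dict" in d else d
```

Callers can then hand the unwrapped dict directly to `model.load_state_dict(...)` regardless of which checkpoint format was on disk.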
About this extraction
This page contains the full source code of the AIGODLIKE/ComfyUI-ToonCrafter GitHub repository, extracted and formatted as plain text. The extraction includes 98 files (933.1 KB), approximately 235.4k tokens, and a symbol index with 1216 extracted functions, classes, methods, constants, and types.