Repository: AIGODLIKE/ComfyUI-ToonCrafter
Branch: main
Commit: 96024189ecb2
Files: 98
Total size: 933.1 KB

Directory structure:
gitextract_h58eh22a/

├── .github/
│   └── workflows/
│       └── publish.yml
├── .gitignore
├── LICENSE
├── ToonCrafter/
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── cldm/
│   │   ├── cldm.py
│   │   ├── ddim_hacked.py
│   │   ├── hack.py
│   │   ├── logger.py
│   │   └── model.py
│   ├── configs/
│   │   ├── cldm_v21.yaml
│   │   ├── inference_512_v1.0.yaml
│   │   ├── training_1024_v1.0/
│   │   │   ├── config.yaml
│   │   │   └── run.sh
│   │   └── training_512_v1.0/
│   │       ├── config.yaml
│   │       └── run.sh
│   ├── gradio_app.py
│   ├── ldm/
│   │   ├── data/
│   │   │   ├── __init__.py
│   │   │   └── util.py
│   │   ├── models/
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion/
│   │   │       ├── __init__.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       ├── dpm_solver/
│   │   │       │   ├── __init__.py
│   │   │       │   ├── dpm_solver.py
│   │   │       │   └── sampler.py
│   │   │       ├── plms.py
│   │   │       └── sampling_util.py
│   │   ├── modules/
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   ├── upscaling.py
│   │   │   │   └── util.py
│   │   │   ├── distributions/
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders/
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   └── utils_image.py
│   │   │   └── midas/
│   │   │       ├── __init__.py
│   │   │       ├── api.py
│   │   │       ├── midas/
│   │   │       │   ├── __init__.py
│   │   │       │   ├── base_model.py
│   │   │       │   ├── blocks.py
│   │   │       │   ├── dpt_depth.py
│   │   │       │   ├── midas_net.py
│   │   │       │   ├── midas_net_custom.py
│   │   │       │   ├── transforms.py
│   │   │       │   └── vit.py
│   │   │       └── utils.py
│   │   └── util.py
│   ├── lvdm/
│   │   ├── __init__.py
│   │   ├── basics.py
│   │   ├── common.py
│   │   ├── data/
│   │   │   ├── base.py
│   │   │   └── webvid.py
│   │   ├── distributions.py
│   │   ├── ema.py
│   │   ├── models/
│   │   │   ├── autoencoder.py
│   │   │   ├── autoencoder_dualref.py
│   │   │   ├── ddpm3d.py
│   │   │   ├── samplers/
│   │   │   │   ├── ddim.py
│   │   │   │   └── ddim_multiplecond.py
│   │   │   └── utils_diffusion.py
│   │   └── modules/
│   │       ├── attention.py
│   │       ├── attention_svd.py
│   │       ├── encoders/
│   │       │   ├── condition.py
│   │       │   └── resampler.py
│   │       ├── networks/
│   │       │   ├── ae_modules.py
│   │       │   └── openaimodel3d.py
│   │       └── x_transformer.py
│   ├── main/
│   │   ├── __init__.py
│   │   ├── callbacks.py
│   │   ├── trainer.py
│   │   ├── utils_data.py
│   │   └── utils_train.py
│   ├── prompts/
│   │   └── 512_interp/
│   │       └── prompts.txt
│   ├── requirements.txt
│   ├── scripts/
│   │   ├── evaluation/
│   │   │   ├── ddp_wrapper.py
│   │   │   ├── funcs.py
│   │   │   └── inference.py
│   │   ├── gradio/
│   │   │   ├── i2v_test.py
│   │   │   └── i2v_test_application.py
│   │   └── run.sh
│   └── utils/
│       ├── __init__.py
│       ├── save_video.py
│       └── utils.py
├── __init__.py
├── pre_run.py
├── pyproject.toml
├── readme.md
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
      - master
    paths:
      - "pyproject.toml"

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@main
        with:
          ## Add your personal access token to your GitHub repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}


================================================
FILE: .gitignore
================================================
.DS_Store
*pyc
.vscode
__pycache__
*.egg-info

checkpoints
ToonCrafter/checkpoints
results
backup
LOG
/models
ToonCrafter/tmp
Thumbs.db

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright Tencent

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: ToonCrafter/.gitignore
================================================
.DS_Store
*pyc
.vscode
__pycache__
*.egg-info

checkpoints
results
backup
LOG

================================================
FILE: ToonCrafter/LICENSE
================================================
   [Content identical to the root LICENSE above: Apache License, Version 2.0, Copyright Tencent.]

================================================
FILE: ToonCrafter/README.md
================================================
## ___***ToonCrafter: Generative Cartoon Interpolation***___
<!-- ![](./assets/logo_long.png#gh-light-mode-only){: width="50%"} -->
<!-- ![](./assets/logo_long_dark.png#gh-dark-mode-only=100x20) -->
<div align="center">



</div>
 
## 🔆 Introduction

⚠️ Please check our [disclaimer](#disc) first.

🤗 ToonCrafter can interpolate two cartoon images by leveraging the pre-trained image-to-video diffusion priors. Please check our project page and paper for more information. <br>







### 1.1 Showcases (512x320)
<table class="center">
    <tr style="font-weight: bolder;text-align:center;">
        <td>Input starting frame</td>
        <td>Input ending frame</td>
        <td>Generated video</td>
    </tr>
  <tr>
  <td>
    <img src=assets/72109_125.mp4_00-00.png width="250">
  </td>
  <td>
    <img src=assets/72109_125.mp4_00-01.png width="250">
  </td>
  <td>
    <img src=assets/00.gif width="250">
  </td>
  </tr>


   <tr>
  <td>
    <img src=assets/Japan_v2_2_062266_s2_frame1.png width="250">
  </td>
  <td>
    <img src=assets/Japan_v2_2_062266_s2_frame3.png width="250">
  </td>
  <td>
    <img src=assets/03.gif width="250">
  </td>
  </tr>
  <tr>
  <td>
    <img src=assets/Japan_v2_1_070321_s3_frame1.png width="250">
  </td>
  <td>
    <img src=assets/Japan_v2_1_070321_s3_frame3.png width="250">
  </td>
  <td>
    <img src=assets/02.gif width="250">
  </td>
  </tr> 
  <tr>
  <td>
    <img src=assets/74302_1349_frame1.png width="250">
  </td>
  <td>
    <img src=assets/74302_1349_frame3.png width="250">
  </td>
  <td>
    <img src=assets/01.gif width="250">
  </td>
  </tr>
</table>

### 1.2 Sparse sketch guidance
<table class="center">
    <tr style="font-weight: bolder;text-align:center;">
        <td>Input starting frame</td>
        <td>Input ending frame</td>
        <td>Input sketch guidance</td>
        <td>Generated video</td>
    </tr>
  <tr>
  <td>
    <img src=assets/72105_388.mp4_00-00.png width="200">
  </td>
  <td>
    <img src=assets/72105_388.mp4_00-01.png width="200">
  </td>
  <td>
    <img src=assets/06.gif width="200">
  </td>
   <td>
    <img src=assets/07.gif width="200">
  </td>
  </tr>

  <tr>
  <td>
    <img src=assets/72110_255.mp4_00-00.png width="200">
  </td>
  <td>
    <img src=assets/72110_255.mp4_00-01.png width="200">
  </td>
  <td>
    <img src=assets/12.gif width="200">
  </td>
   <td>
    <img src=assets/13.gif width="200">
  </td>
  </tr>


</table>


### 2. Applications
#### 2.1 Cartoon Sketch Interpolation (see project page for more details)
<table class="center">
    <tr style="font-weight: bolder;text-align:center;">
        <td>Input starting frame</td>
        <td>Input ending frame</td>
        <td>Generated video</td>
    </tr>

  <tr>
  <td>
    <img src=assets/frame0001_10.png width="250">
  </td>
  <td>
    <img src=assets/frame0016_10.png width="250">
  </td>
  <td>
    <img src=assets/10.gif width="250">
  </td>
  </tr>


   <tr>
  <td>
    <img src=assets/frame0001_11.png width="250">
  </td>
  <td>
    <img src=assets/frame0016_11.png width="250">
  </td>
  <td>
    <img src=assets/11.gif width="250">
  </td>
  </tr>

</table>


#### 2.2 Reference-based Sketch Colorization
<table class="center">
    <tr style="font-weight: bolder;text-align:center;">
        <td>Input sketch</td>
        <td>Input reference</td>
        <td>Colorization results</td>
    </tr>
    
  <tr>
  <td>
    <img src=assets/04.gif width="250">
  </td>
  <td>
    <img src=assets/frame0001_05.png width="250">
  </td>
  <td>
    <img src=assets/05.gif width="250">
  </td>
  </tr>


   <tr>
  <td>
    <img src=assets/08.gif width="250">
  </td>
  <td>
    <img src=assets/frame0001_09.png width="250">
  </td>
  <td>
    <img src=assets/09.gif width="250">
  </td>
  </tr>

</table>







## 📝 Changelog
- [ ] Add sketch control and colorization function.
- __[2024.05.29]__: 🔥🔥 Release code and model weights.
- __[2024.05.28]__: Launch the project page and update the arXiv preprint.
<br>


## 🧰 Models

|Model|Resolution|GPU Mem. & Inference Time (A100, DDIM 50 steps)|Checkpoint|
|:---------|:---------|:--------|:--------|
|ToonCrafter_512|320x512| TBD (`perframe_ae=True`)|[Hugging Face](https://huggingface.co/Doubiiu/ToonCrafter/blob/main/model.ckpt)|


Currently, ToonCrafter supports generating videos of up to 16 frames at a resolution of 512x320. Inference time can be reduced by using fewer DDIM steps.
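The steps/speed trade-off comes from DDIM subsampling the 1000 DDPM timesteps: each kept timestep costs one UNet forward pass. A minimal sketch of uniform timestep subsampling (illustrative only; it mirrors the `uniform` branch of `make_ddim_timesteps` in `ldm/modules/diffusionmodules/util.py`):

```python
import numpy as np

def uniform_ddim_timesteps(num_ddim_steps, num_ddpm_timesteps=1000):
    """Evenly subsample the DDPM timesteps. Each kept step is one UNet
    forward pass, so halving the step count roughly halves sampling time."""
    stride = num_ddpm_timesteps // num_ddim_steps
    # +1 shifts the schedule off t=0, matching the repo's convention
    return np.asarray(list(range(0, num_ddpm_timesteps, stride))) + 1

steps = uniform_ddim_timesteps(50)
print(len(steps))  # 50 forward passes instead of 1000
```

So dropping from 50 to 25 steps halves the number of denoising passes, at some cost to motion quality.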



## ⚙️ Setup

### Install Environment via Anaconda (Recommended)
```bash
conda create -n tooncrafter python=3.8.5
conda activate tooncrafter
pip install -r requirements.txt
```


## 💫 Inference
### 1. Command line

Download the pretrained ToonCrafter_512 model and place its `model.ckpt` at `checkpoints/tooncrafter_512_interp_v1/model.ckpt`.
```bash
  sh scripts/run.sh
```


### 2. Local Gradio demo

Download the pretrained model and place it in the checkpoint directory described above.
```bash
  python gradio_app.py 
```






<!-- ## 🤝 Community Support -->



<a name="disc"></a>
## 📢 Disclaimer
Calm down. Our framework opens up the era of generative cartoon interpolation, but due to the variability of the generative video prior, the success rate is not guaranteed.

⚠️ This is an open-source research exploration, not a commercial product; it may not meet all your expectations.

This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.

================================================
FILE: ToonCrafter/__init__.py
================================================
import sys
from pathlib import Path
sys.path.append(Path(__file__).parent.as_posix())


================================================
FILE: ToonCrafter/cldm/cldm.py
================================================
import einops
import torch
import torch as th
import torch.nn as nn

from ToonCrafter.ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)

from einops import rearrange, repeat
from torchvision.utils import make_grid
from ToonCrafter.ldm.modules.attention import SpatialTransformer
from ToonCrafter.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from lvdm.modules.networks.openaimodel3d import UNetModel
from ToonCrafter.ldm.models.diffusion.ddpm import LatentDiffusion
from ToonCrafter.ldm.util import log_txt_as_img, exists, instantiate_from_config
from ToonCrafter.ldm.models.diffusion.ddim import DDIMSampler


class ControlledUnetModel(UNetModel):
    def forward(self, x, timesteps, context=None, features_adapter=None, fs=None, control=None, **kwargs):
        b,_,t,_,_ = x.shape
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
        emb = self.time_embed(t_emb)
        ## repeat t times for context [(b t) 77 768] & time embedding
        ## check if we use per-frame image conditioning
        _, l_context, _ = context.shape
        if l_context == 77 + t*16: ## !!! HARD CODE here
            context_text, context_img = context[:,:77,:], context[:,77:,:]
            context_text = context_text.repeat_interleave(repeats=t, dim=0)
            context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
            context = torch.cat([context_text, context_img], dim=1)
        else:
            context = context.repeat_interleave(repeats=t, dim=0)
        emb = emb.repeat_interleave(repeats=t, dim=0)
        
        ## always in shape (b t) c h w, except for temporal layer
        x = rearrange(x, 'b c t h w -> (b t) c h w')

        ## combine emb
        if self.fs_condition:
            if fs is None:
                fs = torch.tensor(
                    [self.default_fs] * b, dtype=torch.long, device=x.device)
            fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)

            fs_embed = self.fps_embedding(fs_emb)
            fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
            emb = emb + fs_embed

        h = x.type(self.dtype)
        adapter_idx = 0
        hs = []
        with torch.no_grad():
            for idx, module in enumerate(self.input_blocks):
                h = module(h, emb, context=context, batch_size=b)
                if idx == 0 and self.addition_attention:
                    h = self.init_attn(h, emb, context=context, batch_size=b)
                ## plug-in adapter features
                if ((idx + 1) % 3 == 0) and features_adapter is not None:
                    h = h + features_adapter[adapter_idx]
                    adapter_idx += 1
                hs.append(h)
            if features_adapter is not None:
                assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'

            h = self.middle_block(h, emb, context=context, batch_size=b)
        
        if control is not None:
            h += control.pop()
        
        for module in self.output_blocks:
            if control is None:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
            h = module(h, emb, context=context, batch_size=b)

        h = h.type(x.dtype)
        y = self.out(h)
        
        # reshape back to (b c t h w)
        y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
        return y


class ControlNet(nn.Module):
    def __init__(
            self,
            image_size,
            in_channels,
            model_channels,
            hint_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=-1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            use_spatial_transformer=False,  # custom transformer support
            transformer_depth=1,  # custom transformer support
            context_dim=None,  # custom transformer support
            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
            legacy=True,
            disable_self_attentions=None,
            num_attention_blocks=None,
            disable_middle_self_attn=False,
            use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
     
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context, True))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs


class ControlLDM(LatentDiffusion):

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()
        return x, dict(c_crossattn=[c], c_concat=[control])

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        return self.get_learned_conditioning([""] * N)

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N)
            uc_cat = c_cat  # torch.zeros_like(c_cat)
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc_full,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        if not self.sd_locked:
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()


================================================
FILE: ToonCrafter/cldm/ddim_hacked.py
================================================
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from ToonCrafter.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates
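A minimal standalone sketch (illustrative only, not repo code) of how the loop above pairs each reverse-time step with its schedule index: the DDIM timestep subsequence is flipped so sampling runs from the noisiest step down to step 1, while `index = total_steps - i - 1` walks the alpha tables backwards.

```python
import numpy as np

ddpm_num_timesteps = 1000
num_ddim_steps = 50

# uniform DDIM subsequence, mirroring the "uniform" spacing convention
c = ddpm_num_timesteps // num_ddim_steps
ddim_timesteps = np.asarray(list(range(0, ddpm_num_timesteps, c))) + 1

time_range = np.flip(ddim_timesteps)   # iterate from noisy to clean
total_steps = ddim_timesteps.shape[0]

# (index, timestep) pairs exactly as the sampling loop computes them
pairs = [(total_steps - i - 1, step) for i, step in enumerate(time_range)]
```

The first iteration uses the highest timestep with the last table index; the final iteration reaches index 0 at timestep 1.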

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            model_t = self.model.apply_model(x, t, c)
            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
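A tiny standalone sketch (dummy numbers, not repo code) of the classifier-free guidance mix computed above, `model_uncond + scale * (model_t - model_uncond)`: the unconditional prediction is extrapolated toward the conditional one.

```python
scale = 7.5
cond = [1.0, 2.0]      # stand-in conditional model outputs
uncond = [0.5, 1.0]    # stand-in unconditional model outputs

# same arithmetic as the guided branch in p_sample_ddim, per element
guided = [u + scale * (c - u) for c, u in zip(cond, uncond)]
```

With `scale == 1.0` the mix reduces to the conditional prediction alone, which is why the code skips the extra unconditional forward pass in that case.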

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        num_reference_steps = timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
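An illustrative scalar version (not repo code) of the closed-form forward diffusion that `stochastic_encode` applies per element: `x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise`.

```python
import math

def q_sample(x0, abar_t, noise):
    # abar_t is the cumulative alpha at timestep t, in [0, 1]
    return math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * noise
```

At `abar_t = 1` no noise is mixed in (early timesteps); at `abar_t = 0` only noise remains, which is why large `t` destroys the signal and exact reconstruction is not guaranteed.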

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec


================================================
FILE: ToonCrafter/cldm/hack.py
================================================
import torch
import einops

import ToonCrafter.ldm.modules.encoders.modules
import ToonCrafter.ldm.modules.attention

from transformers import logging
from ToonCrafter.ldm.modules.attention import default


def disable_verbosity():
    logging.set_verbosity_error()
    print('logging improved.')
    return


def enable_sliced_attention():
    ToonCrafter.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
    return


def hack_everything(clip_skip=0):
    disable_verbosity()
    ToonCrafter.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    ToonCrafter.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    print('Enabled clip hacks.')
    return


# Written by Lvmin
def _hacked_clip_forward(self, text):
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]

    def transformer_encode(t):
        if self.clip_skip > 1:
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state

    def split(x):
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z
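A self-contained sketch (fake token ids) of the split/pad scheme above: the untruncated token stream is cut into three windows of 75 ids, each wrapped with BOS/EOS and padded to CLIP's 77-token context window. The numeric ids below are the usual OpenAI CLIP special-token values, used here only as placeholders.

```python
BOS, EOS, PAD = 49406, 49407, 49407  # typical OpenAI CLIP ids; illustrative

def split(x):
    return x[0:75], x[75:150], x[150:225]

def pad(x, p, n):
    return x[:n] if len(x) >= n else x + [p] * (n - len(x))

raw = list(range(100))  # 100 fake token ids, longer than one 75-id window
chunks = [pad([BOS] + c + [EOS], PAD, 77) for c in split(raw)]
```

Every chunk comes out exactly 77 ids long, so the three encoded windows can be concatenated into one long context, as `_hacked_clip_forward` does with `einops.rearrange`.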


# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    limit = k.shape[0]
    att_step = 1
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer

        del sim_buffer
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
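A self-contained numerical check (dummy data, not repo code) that slicing attention over the batch axis, as the hacked forward above does with `att_step = 1`, is exact: chunking changes peak memory, not the output.

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8, 16))   # (batch*heads, tokens, dim)
k = rng.standard_normal((4, 8, 16))
v = rng.standard_normal((4, 8, 16))
scale = 16 ** -0.5

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# full attention in one shot
full = softmax(np.einsum('bid,bjd->bij', q, k) * scale) @ v

# sliced: one batch-row at a time, mirroring the loop above
sliced = np.zeros_like(full)
for i in range(q.shape[0]):
    sliced[i] = softmax(q[i] @ k[i].T * scale) @ v[i]
```

The two results agree to floating-point precision, since softmax and the value projection act independently per batch row.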


================================================
FILE: ToonCrafter/cldm/logger.py
================================================
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only


class ImageLogger(Callback):
    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        if increase_log_steps:
            self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        else:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "image_log", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)
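A tiny standalone sketch (dummy array) of the value conversion `log_local` performs when `rescale` is set: model outputs in [-1, 1] are shifted to [0, 1] and then quantized to uint8 for PNG export.

```python
import numpy as np

grid = np.array([[-1.0, 0.0, 1.0]])   # stand-in pixel values in [-1, 1]
grid = (grid + 1.0) / 2.0             # -> [0, 1]
img = (grid * 255).astype(np.uint8)   # -> [0, 255], truncating fractions
```

Note the cast truncates rather than rounds, so 0.0 maps to 127, not 128.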

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        return check_idx % self.batch_freq == 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled:
            self.log_img(pl_module, batch, batch_idx, split="train")


================================================
FILE: ToonCrafter/cldm/model.py
================================================
import os
import torch

from omegaconf import OmegaConf
from comfy.ldm.util import instantiate_from_config


def get_state_dict(d):
    return d.get('state_dict', d)


def load_state_dict(ckpt_path, location='cpu'):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = torch.load(ckpt_path, map_location=torch.device(location))
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
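A tiny sketch (toy dicts) of the unwrap convention `load_state_dict` relies on: PyTorch Lightning checkpoints nest the weights under a `'state_dict'` key, while plain weight dumps and safetensors files do not, and `dict.get` handles both.

```python
def get_state_dict(d):
    # same one-liner as above: unwrap if nested, pass through otherwise
    return d.get('state_dict', d)

wrapped = {'state_dict': {'w': 1}, 'epoch': 3}  # Lightning-style checkpoint
plain = {'w': 1}                                # raw weight dump
```

Both inputs normalize to the same weight dict, and the call is idempotent, so applying it a second time is harmless.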


def create_model(config_path):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return model


================================================
FILE: ToonCrafter/configs/cldm_v21.yaml
================================================
control_stage_config:
  target: ToonCrafter.cldm.cldm.ControlNet
  params:
    use_checkpoint: True
    image_size: 32 # unused
    in_channels: 4
    hint_channels: 1
    model_channels: 320
    attention_resolutions: [ 4, 2, 1 ]
    num_res_blocks: 2
    channel_mult: [ 1, 2, 4, 4 ]
    num_head_channels: 64 # need to fix for flash-attn
    use_spatial_transformer: True
    use_linear_in_transformer: True
    transformer_depth: 1
    context_dim: 1024
    legacy: False


================================================
FILE: ToonCrafter/configs/inference_512_v1.0.yaml
================================================
model:
  target: lvdm.models.ddpm3d.LatentVisualDiffusion
  params:
    rescale_betas_zero_snr: True
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    timesteps: 1000
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: False
    conditioning_key: hybrid
    image_size: [40, 64]
    channels: 4
    scale_by_std: False
    scale_factor: 0.18215
    use_ema: False
    uncond_type: 'empty_seq'
    use_dynamic_rescale: true
    base_scale: 0.7
    fps_condition_type: 'fps'
    perframe_ae: True
    loop_video: true
    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: True
        temporal_conv: True
        temporal_attention: True
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: False
        temporal_length: 16
        addition_attention: true
        image_cross_attention: true
        default_fs: 24
        fs_condition: true

    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL_Dualref
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: "penultimate"

    img_cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true
    
    image_proj_stage_config:
      target: lvdm.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280
        output_dim: 1024
        ff_mult: 4
        video_length: 16


================================================
FILE: ToonCrafter/configs/training_1024_v1.0/config.yaml
================================================
model:
  pretrained_checkpoint: checkpoints/dynamicrafter_1024_v1/model.ckpt
  base_learning_rate: 1.0e-05
  scale_lr: False
  target: lvdm.models.ddpm3d.LatentVisualDiffusion
  params:
    rescale_betas_zero_snr: True
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: False
    image_proj_model_trainable: True
    conditioning_key: hybrid
    image_size: [72, 128]
    channels: 4
    scale_by_std: False
    scale_factor: 0.18215
    use_ema: False
    uncond_prob: 0.05
    uncond_type: 'empty_seq'
    rand_cond_frame: true
    use_dynamic_rescale: true
    base_scale: 0.3
    fps_condition_type: 'fps'
    perframe_ae: True

    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: True
        temporal_conv: True
        temporal_attention: True
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: False
        temporal_length: 16
        addition_attention: true
        image_cross_attention: true
        default_fs: 10
        fs_condition: true

    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: "penultimate"

    img_cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true
    
    image_proj_stage_config:
      target: lvdm.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280
        output_dim: 1024
        ff_mult: 4
        video_length: 16

data:
  target: utils_data.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 12
    wrap: false
    train:
      target: lvdm.data.webvid.WebVid
      params:
        data_dir: <WebVid10M DATA>
        meta_path: <.csv FILE>
        video_length: 16
        frame_stride: 6
        load_raw_resolution: true
        resolution: [576, 1024]
        spatial_transform: resize_center_crop
        random_fs: true  ## if true, we uniformly sample fs with max_fs=frame_stride (above)

lightning:
  precision: 16
  # strategy: deepspeed_stage_2
  trainer:
    benchmark: True
    accumulate_grad_batches: 2
    max_steps: 100000
    # logger
    log_every_n_steps: 50
    # val
    val_check_interval: 0.5
    gradient_clip_algorithm: 'norm'
    gradient_clip_val: 0.5
  callbacks:
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        every_n_train_steps: 9000 #1000
        filename: "{epoch}-{step}"
        save_weights_only: True
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: '{epoch}-{step}'
        save_weights_only: True
        every_n_train_steps: 10000 #20000 # 3s/step*2w=
    batch_logger:
      target: callbacks.ImageLogger
      params:
        batch_frequency: 500
        to_local: False
        max_images: 8
        log_images_kwargs:
          ddim_steps: 50
          unconditional_guidance_scale: 7.5
          timestep_spacing: uniform_trailing
          guidance_rescale: 0.7

================================================
FILE: ToonCrafter/configs/training_1024_v1.0/run.sh
================================================
# NCCL configuration
# export NCCL_DEBUG=INFO
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export NCCL_TOPO_FILE=/tmp/topo.txt

# args
name="training_1024_v1.0"
config_file=configs/${name}/config.yaml

# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="<YOUR_SAVE_ROOT_DIR>"

mkdir -p $save_root/$name

## run
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
./main/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices $HOST_GPU_NUM \
lightning.trainer.num_nodes=1

## debugging
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
# --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
# ./main/trainer.py \
# --base $config_file \
# --train \
# --name $name \
# --logdir $save_root \
# --devices 4 \
# lightning.trainer.num_nodes=1

================================================
FILE: ToonCrafter/configs/training_512_v1.0/config.yaml
================================================
model:
  pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt
  base_learning_rate: 1.0e-05
  scale_lr: False
  target: lvdm.models.ddpm3d.LatentVisualDiffusion
  params:
    rescale_betas_zero_snr: True
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: False
    image_proj_model_trainable: True
    conditioning_key: hybrid
    image_size: [40, 64]
    channels: 4
    scale_by_std: False
    scale_factor: 0.18215
    use_ema: False
    uncond_prob: 0.05
    uncond_type: 'empty_seq'
    rand_cond_frame: true
    use_dynamic_rescale: true
    base_scale: 0.7
    fps_condition_type: 'fps'
    perframe_ae: True

    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: True
        temporal_conv: True
        temporal_attention: True
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: False
        temporal_length: 16
        addition_attention: true
        image_cross_attention: true
        default_fs: 10
        fs_condition: true

    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: "penultimate"

    img_cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true
    
    image_proj_stage_config:
      target: lvdm.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280
        output_dim: 1024
        ff_mult: 4
        video_length: 16

data:
  target: utils_data.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 12
    wrap: false
    train:
      target: lvdm.data.webvid.WebVid
      params:
        data_dir: <WebVid10M DATA>
        meta_path: <.csv FILE>
        video_length: 16
        frame_stride: 6
        load_raw_resolution: true
        resolution: [320, 512]
        spatial_transform: resize_center_crop
        random_fs: true  ## if true, we uniformly sample fs with max_fs=frame_stride (above)

lightning:
  precision: 16
  # strategy: deepspeed_stage_2
  trainer:
    benchmark: True
    accumulate_grad_batches: 2
    max_steps: 100000
    # logger
    log_every_n_steps: 50
    # val
    val_check_interval: 0.5
    gradient_clip_algorithm: 'norm'
    gradient_clip_val: 0.5
  callbacks:
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        every_n_train_steps: 9000 #1000
        filename: "{epoch}-{step}"
        save_weights_only: True
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: '{epoch}-{step}'
        save_weights_only: True
        every_n_train_steps: 10000 #20000 # 3s/step*2w=
    batch_logger:
      target: callbacks.ImageLogger
      params:
        batch_frequency: 500
        to_local: False
        max_images: 8
        log_images_kwargs:
          ddim_steps: 50
          unconditional_guidance_scale: 7.5
          timestep_spacing: uniform_trailing
          guidance_rescale: 0.7

================================================
FILE: ToonCrafter/configs/training_512_v1.0/run.sh
================================================
# NCCL configuration
# export NCCL_DEBUG=INFO
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export NCCL_TOPO_FILE=/tmp/topo.txt

# args
name="training_512_v1.0"
config_file=configs/${name}/config.yaml

# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="<YOUR_SAVE_ROOT_DIR>"

mkdir -p $save_root/$name

## run
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
./main/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices $HOST_GPU_NUM \
lightning.trainer.num_nodes=1

## debugging
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
# --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
# ./main/trainer.py \
# --base $config_file \
# --train \
# --name $name \
# --logdir $save_root \
# --devices 4 \
# lightning.trainer.num_nodes=1

================================================
FILE: ToonCrafter/gradio_app.py
================================================
import os
import argparse
import sys
sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
import gradio as gr
from scripts.gradio.i2v_test_application import Image2Video


i2v_examples_interp_512 = [
    ['prompts/512_interp/74906_1462_frame1.png', 'walking man', 50, 7.5, 1.0, 10, 123, 'prompts/512_interp/74906_1462_frame3.png'],
    ['prompts/512_interp/Japan_v2_2_062266_s2_frame1.png', 'an anime scene', 50, 7.5, 1.0, 10, 789, 'prompts/512_interp/Japan_v2_2_062266_s2_frame3.png'],
    ['prompts/512_interp/Japan_v2_3_119235_s2_frame1.png', 'an anime scene', 50, 7.5, 1.0, 10, 123, 'prompts/512_interp/Japan_v2_3_119235_s2_frame3.png'],
]


def dynamicrafter_demo(result_dir='./tmp/', res=512):
    if res == 1024:
        resolution = '576_1024'
        css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}"""
    elif res == 512:
        resolution = '320_512'
        css = """#input_img {max-width: 512px !important} #input_img2 {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}"""
    elif res == 256:
        resolution = '256_256'
        css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}"""
    else:
        raise NotImplementedError(f"Unsupported resolution: {res}")
    image2video = Image2Video(result_dir, resolution=resolution)
    with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:

        with gr.Tab(label='ToonCrafter_320x512'):
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            i2v_input_image = gr.Image(label="Input Image1", elem_id="input_img")
                        with gr.Row():
                            i2v_input_text = gr.Text(label='Prompts')
                        with gr.Row():
                            i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=50000, step=1, value=123)
                            i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
                            i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
                        with gr.Row():
                            i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
                            i2v_motion = gr.Slider(minimum=5, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=10)
                        i2v_end_btn = gr.Button("Generate")
                    with gr.Column():
                        with gr.Row():
                            i2v_input_image2 = gr.Image(label="Input Image2", elem_id="input_img2")
                        with gr.Row():
                            i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True, show_share_button=True)

                gr.Examples(examples=i2v_examples_interp_512,
                            inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_input_image2],
                            outputs=[i2v_output_video],
                            fn=image2video.get_image,
                            cache_examples=False,
                            )
            i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_input_image2],
                              outputs=[i2v_output_video],
                              fn=image2video.get_image
                              )

    return dynamicrafter_iface


def get_parser():
    parser = argparse.ArgumentParser()
    return parser


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    result_dir = os.path.join('./', 'results')
    dynamicrafter_iface = dynamicrafter_demo(result_dir)
    dynamicrafter_iface.queue(max_size=12)
    dynamicrafter_iface.launch(max_threads=1)
    # dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=80, max_threads=1)
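# The resolution dispatch in dynamicrafter_demo reduces to a small pure mapping from
# `res` to the resolution tag passed into Image2Video. A standalone sketch of that
# logic (pick_resolution is a hypothetical helper name, not part of this repo):

```python
def pick_resolution(res: int) -> str:
    # Same res -> 'H_W' tag dispatch as dynamicrafter_demo, without the CSS strings.
    table = {1024: '576_1024', 512: '320_512', 256: '256_256'}
    try:
        return table[res]
    except KeyError:
        raise NotImplementedError(f"Unsupported resolution: {res}") from None
```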


================================================
FILE: ToonCrafter/ldm/data/__init__.py
================================================


================================================
FILE: ToonCrafter/ldm/data/util.py
================================================
import torch

from ToonCrafter.ldm.modules.midas.api import load_midas_transform


class AddMiDaS(object):
    def __init__(self, model_type):
        super().__init__()
        self.transform = load_midas_transform(model_type)

    def pt2np(self, x):
        x = ((x + 1.0) * .5).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        x = torch.from_numpy(x) * 2 - 1.
        return x

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = self.pt2np(sample['jpg'])
        x = self.transform({"image": x})["image"]
        sample['midas_in'] = x
        return sample
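# AddMiDaS.pt2np and np2pt are the two affine maps between the model's [-1, 1] range
# and MiDaS's [0, 1] range. A stdlib-only sketch of the same maps on plain floats
# (tensors swapped for scalars, purely to make the round-trip explicit):

```python
def pt2np(x: float) -> float:
    # [-1, 1] -> [0, 1], mirroring AddMiDaS.pt2np
    return (x + 1.0) * 0.5

def np2pt(x: float) -> float:
    # [0, 1] -> [-1, 1], mirroring AddMiDaS.np2pt
    return x * 2.0 - 1.0
```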

================================================
FILE: ToonCrafter/ldm/models/autoencoder.py
================================================
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from ToonCrafter.ldm.modules.diffusionmodules.model import Encoder, Decoder
from ToonCrafter.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from ToonCrafter.ldm.util import instantiate_from_config
from ToonCrafter.ldm.modules.ema import LitEma


class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False
                 ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val"+postfix)

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val"+postfix)

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list,
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()  # initialize nn.Module before assigning attributes
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
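# AutoencoderKL.encode returns a DiagonalGaussianDistribution built from `moments`
# (mean and log-variance stacked along the channel axis); sample() draws via the
# reparameterization trick, mode() just returns the mean. A scalar, stdlib-only
# sketch of those two operations, assuming the standard mean + std * eps form:

```python
import math

def diag_gaussian_sample(mean: float, logvar: float, eps: float) -> float:
    # reparameterization trick: x = mean + std * eps, with std = exp(0.5 * logvar)
    return mean + math.exp(0.5 * logvar) * eps

def diag_gaussian_mode(mean: float, logvar: float) -> float:
    # the mode of a Gaussian is its mean; the variance plays no role here
    return mean
```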



================================================
FILE: ToonCrafter/ldm/models/diffusion/__init__.py
================================================


================================================
FILE: ToonCrafter/ldm/models/diffusion/ddim.py
================================================
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from ToonCrafter.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if isinstance(attr, torch.Tensor):
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,  # must come in the same format as the conditioning, e.g. as encoded tokens
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([
                            unconditional_conditioning[k][i],
                            c[k][i]]) for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([
                                unconditional_conditioning[k],
                                c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
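# p_sample_ddim implements the DDIM update
#   x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma^2) * e_t + sigma * noise,
# with pred_x0 recovered from the eps-parameterization x_t = sqrt(a_t)*x0 + sqrt(1-a_t)*eps.
# A scalar, stdlib-only sketch of the deterministic (eta = 0, sigma = 0) case, checking
# that an exact noise prediction moves x_t onto the correct next-step point:

```python
import math

def ddim_step(x_t: float, e_t: float, a_t: float, a_prev: float,
              sigma: float = 0.0, noise: float = 0.0) -> float:
    # pred_x0 from the eps-parameterization of the forward process
    pred_x0 = (x_t - math.sqrt(1.0 - a_t) * e_t) / math.sqrt(a_t)
    # direction pointing back toward x_t, scaled for the previous timestep
    dir_xt = math.sqrt(1.0 - a_prev - sigma ** 2) * e_t
    return math.sqrt(a_prev) * pred_x0 + dir_xt + sigma * noise
```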

================================================
FILE: ToonCrafter/ldm/models/diffusion/ddpm.py
================================================
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import ListConfig

from ToonCrafter.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ToonCrafter.ldm.modules.ema import LitEma
from ToonCrafter.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ToonCrafter.ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ToonCrafter.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ToonCrafter.ldm.models.diffusion.ddim import DDIMSampler


__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 make_it_fit=False,
                 ucg_training=None,
                 reset_ema=False,
                 reset_num_ema_updates=False,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        self.make_it_fit = make_it_fit
        if reset_ema: assert exists(ckpt_path)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
            if reset_ema:
                assert self.use_ema
                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(logvar, requires_grad=True)
        else:
            self.register_buffer('logvar', logvar)

        self.ucg_training = ucg_training or dict()
        if self.ucg_training:
            self.ucg_prng = np.random.RandomState()

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * torch.sqrt(to_torch(alphas_cumprod)) / (2. * 1 - to_torch(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).any()
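    # A standalone numeric sketch of the schedule arithmetic above, assuming the
    # "linear" branch of make_beta_schedule (interpolation in sqrt-space, then
    # squared); it also checks the closed-form identity stated in the comment
    # next to posterior_variance. Plain NumPy, independent of the buffers here.

```python
import numpy as np

T = 1000
linear_start, linear_end = 1e-4, 2e-2
# linear schedule: interpolate between sqrt(start) and sqrt(end), then square
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, T) ** 2
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

# q(x_{t-1} | x_t, x_0) variance with v_posterior = 0
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# equivalent form 1 / (1 / (1 - a_{t-1}) + alpha_t / beta_t), valid for t >= 1
alt = 1.0 / (1.0 / (1.0 - alphas_cumprod_prev[1:]) + alphas[1:] / betas[1:])
```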

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @torch.no_grad()
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        if self.make_it_fit:
            n_params = len([name for name, _ in
                            itertools.chain(self.named_parameters(),
                                            self.named_buffers())])
            for name, param in tqdm(
                    itertools.chain(self.named_parameters(),
                                    self.named_buffers()),
                    desc="Fitting old weights to new weights",
                    total=n_params
            ):
                if name not in sd:
                    continue
                old_shape = sd[name].shape
                new_shape = param.shape
                assert len(old_shape) == len(new_shape)
                if len(new_shape) > 2:
                    # we only modify first two axes
                    assert new_shape[2:] == old_shape[2:]
                # assumes first axis corresponds to output dim
                if new_shape != old_shape:
                    new_param = param.clone()
                    old_param = sd[name]
                    if len(new_shape) == 1:
                        for i in range(new_param.shape[0]):
                            new_param[i] = old_param[i % old_shape[0]]
                    elif len(new_shape) >= 2:
                        for i in range(new_param.shape[0]):
                            for j in range(new_param.shape[1]):
                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

                        n_used_old = torch.ones(old_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_old[j % old_shape[1]] += 1
                        n_used_new = torch.zeros(new_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_new[j] = n_used_old[j % old_shape[1]]

                        n_used_new = n_used_new[None, :]
                        while len(n_used_new.shape) < len(new_shape):
                            n_used_new = n_used_new.unsqueeze(-1)
                        new_param /= n_used_new

                    sd[name] = new_param

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys:\n {missing}")
        if len(unexpected) > 0:
            print(f"\nUnexpected Keys:\n {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
        )
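    # A quick standalone check (plain tensors, toy shapes) that the inversions
    # above are exact: with z = sqrt(a)*x0 + sqrt(1-a)*eps (q_sample) and
    # v = sqrt(a)*eps - sqrt(1-a)*x0 (get_v), both x0 and eps are recovered
    # exactly by the formulas used in these helpers.

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(2, 3, 8, 8)           # clean latent
eps = torch.randn_like(x0)             # Gaussian noise
a = torch.tensor(0.3)                  # stands in for alphas_cumprod[t]
sa, s1ma = a.sqrt(), (1 - a).sqrt()

z = sa * x0 + s1ma * eps               # forward process sample (q_sample)

# eps-parameterization inverse (predict_start_from_noise)
x0_from_eps = (1 / a).sqrt() * z - (1 / a - 1).sqrt() * eps

# v-parameterization and its inverses
v = sa * eps - s1ma * x0               # get_v
x0_from_v = sa * z - s1ma * v          # predict_start_from_z_and_v
eps_from_v = sa * v + s1ma * z         # predict_eps_from_z_and_v
```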

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not supported here")
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
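    # The reshape trick above broadcasts a per-sample mask over all non-batch
    # dims so that the noise term is dropped exactly for samples at t == 0.
    # A minimal standalone illustration:

```python
import torch

b = 3
x = torch.zeros(b, 2, 4, 4)
t = torch.tensor([0, 5, 9])            # per-sample timesteps
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
```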

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_v(self, x, noise, t):
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError("unknown loss type '{loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        for k in self.ucg_training:
            p = self.ucg_training[k]["p"]
            val = self.ucg_training[k]["val"]
            if val is None:
                val = ""
            for i in range(len(batch[k])):
                if self.ucg_prng.choice(2, p=[1 - p, p]):
                    batch[k][i] = val

        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt


class LatentDiffusion(DDPM):
    """main class"""

    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 force_null_conditioning=False,
                 *args, **kwargs):
        self.force_null_conditioning = force_null_conditioning
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        reset_ema = kwargs.pop("reset_ema", False)
        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
            if reset_ema:
                assert self.use_ema
                print(
                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()

    def make_cond_schedule(self):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids
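    # On a toy size, the schedule above maps the first num_timesteps_cond
    # entries evenly across [0, T-1] and saturates the rest at T-1
    # (plain tensors, not the module state):

```python
import torch

num_timesteps, num_timesteps_cond = 10, 4
cond_ids = torch.full((num_timesteps,), num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, num_timesteps - 1, num_timesteps_cond)).long()
cond_ids[:num_timesteps_cond] = ids
```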

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")

    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                       force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
         with min distance = 0 at the border and max dist = 0.5 at the image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
        return edge_dist
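    # A standalone replica (toy 5x5 grid, plain tensors) verifying the
    # docstring: the normalized edge distance is 0 on the border and 0.5 at
    # the center.

```python
import torch

def _meshgrid(h, w):
    # same construction as meshgrid() above
    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
    return torch.cat([y, x], dim=-1)

h = w = 5
arr = _meshgrid(h, w) / torch.tensor([h - 1, w - 1]).view(1, 1, 2)
dist_left_up = torch.min(arr, dim=-1, keepdim=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdim=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
```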

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting
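The fold/unfold pair above reconstructs an image from overlapping crops by scatter-adding weighted patches and dividing by the folded weighting (the `normalization` map). A minimal 1-D pure-Python sketch of that overlap-normalization idea (the helper names here are illustrative, not part of this module):

```python
def unfold_1d(x, k, s):
    # extract overlapping windows of length k with stride s
    n_windows = (len(x) - k) // s + 1
    return [x[i * s:i * s + k] for i in range(n_windows)]

def fold_1d(patches, n, k, s):
    # scatter-add windows back onto a length-n signal (torch.nn.Fold sums overlaps)
    out = [0.0] * n
    for i, patch in enumerate(patches):
        for j, v in enumerate(patch):
            out[i * s + j] += v
    return out

signal = [float(i) for i in range(10)]
k, s = 4, 2
patches = unfold_1d(signal, k, s)
# folding an all-ones weighting counts how many windows cover each sample
norm = fold_1d([[1.0] * k for _ in patches], len(signal), k, s)
recon = [v / w for v, w in zip(fold_1d(patches, len(signal), k, s), norm)]
assert recon == signal  # overlap-normalized fold inverts unfold
```

For a uniform weighting the normalization simply counts window coverage per sample; `get_fold_unfold` uses the smoothed weighting from `get_weighting` instead, which downweights crop borders before the same fold-and-divide step.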

    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None and not self.force_null_conditioning:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['caption', 'coordinates_bbox', "txt"]:
                    xc = batch[cond_key]
                elif cond_key in ['class_label', 'cls']:
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, (dict, list)):
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_x:
            out.extend([x])
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z)

    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)
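`forward` draws one uniform random timestep per sample and hands off to `p_losses`, which noises the latent with `q_sample`'s closed form x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps. A scalar sketch of that forward-diffusion step (assuming the standard DDPM `q_sample`, which this file implements with precomputed schedule buffers):

```python
import math

def q_sample_scalar(x_start, alpha_bar_t, noise):
    # closed-form forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
    return math.sqrt(alpha_bar_t) * x_start + math.sqrt(1.0 - alpha_bar_t) * noise

assert q_sample_scalar(2.0, 1.0, 9.0) == 2.0   # abar_t = 1: no noise yet
assert q_sample_scalar(2.0, 0.0, 9.0) == 9.0   # abar_t = 0: pure noise
```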

    def apply_model(self, x_noisy, t, cond, return_ids=False):
        if isinstance(cond, dict):
            # hybrid case, cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
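`apply_model` normalizes bare conditioning tensors into the dict format `DiffusionWrapper.forward` expects, keyed by `c_concat` or `c_crossattn` depending on `conditioning_key`. A standalone sketch of just that normalization step (pure Python, with strings standing in for tensors):

```python
def normalize_cond(cond, conditioning_key):
    # sketch of the dict-normalization performed at the top of apply_model
    if isinstance(cond, dict):
        return cond  # hybrid case: already keyed by c_concat / c_crossattn
    if not isinstance(cond, list):
        cond = [cond]
    key = 'c_concat' if conditioning_key == 'concat' else 'c_crossattn'
    return {key: cond}

assert normalize_cond('emb', 'concat') == {'c_concat': ['emb']}
assert normalize_cond(['emb'], 'crossattn') == {'c_crossattn': ['emb']}
assert normalize_cond({'c_concat': ['z']}, 'hybrid') == {'c_concat': ['z']}
```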

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
        return mean_flat(kl_prior) / np.log(2.0)
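`_prior_bpd` measures KL(q(x_T | x_0) || N(0, I)) in bits by dividing the nats returned by `normal_kl` by ln 2. A scalar version of the closed-form Gaussian KL it relies on (assuming `normal_kl` implements the standard formula, as in the original DDPM codebase):

```python
import math

def normal_kl_scalar(mean1, logvar1, mean2=0.0, logvar2=0.0):
    # KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)), standard closed form
    return 0.5 * (-1.0 + logvar2 - logvar1
                  + math.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * math.exp(-logvar2))

assert normal_kl_scalar(0.0, 0.0) == 0.0   # identical Gaussians: zero KL
nats = normal_kl_scalar(1.0, 0.0)          # unit variance, mean shifted by 1
assert abs(nats - 0.5) < 1e-12
bits = nats / math.log(2.0)                # nats -> bits, as _prior_bpd does
```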

    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
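`p_losses` selects the regression target by parameterization: the clean latent for "x0", the injected noise for "eps", or the v-target for "v" (assuming the standard v-parameterization v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x0, which `get_v` implements elsewhere in this file). A scalar sketch of that branch:

```python
import math

def target_for(parameterization, x_start, noise, alpha_bar_t):
    # scalar version of the target selection in p_losses
    if parameterization == 'x0':
        return x_start
    if parameterization == 'eps':
        return noise
    if parameterization == 'v':
        # standard v-parameterization: v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x0
        return math.sqrt(alpha_bar_t) * noise - math.sqrt(1.0 - alpha_bar_t) * x_start
    raise NotImplementedError(parameterization)

assert target_for('eps', 0.5, 1.25, 0.9) == 1.25
assert target_for('x0', 0.5, 1.25, 0.9) == 0.5
v = target_for('v', 0.5, 1.25, 0.25)
assert abs(v - (0.5 * 1.25 - math.sqrt(0.75) * 0.5)) < 1e-12
```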

    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
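The `nonzero_mask` above gates the injected noise off at t == 0, so the final ancestral step returns the posterior mean deterministically. A scalar sketch of that update rule (illustrative, not the module's tensor code):

```python
import math

def ancestral_step(mean, log_variance, t, noise):
    # x_{t-1} = mean + exp(0.5 * logvar) * noise, with the noise gated off at t == 0
    nonzero = 0.0 if t == 0 else 1.0
    return mean + nonzero * math.exp(0.5 * log_variance) * noise

assert ancestral_step(0.3, -2.0, 0, 5.0) == 0.3   # t == 0: deterministic mean
assert ancestral_step(0.3, 0.0, 7, 1.0) == 1.3    # t > 0: mean + std * noise
```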

    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list)
                        else [c[:batch_size] for c in cond[key]] for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(range(0, timesteps))
        if isinstance(temperature, float):
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):

        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:] == mask.shape[2:]  # spatial sizes have to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None, **kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list)
                        else [c[:batch_size] for c in cond[key]] for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                         shape, cond, verbose=False, **kwargs)

        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True, **kwargs)

        return samples, intermediates

    @torch.no_grad()
    def get_unconditional_conditioning(self, batch_size, null_label=None):
        if null_label is not None:
            xc = null_label
            if isinstance(xc, ListConfig):
                xc = list(xc)
            if isinstance(xc, (dict, list)):
                c = self.get_learned_conditioning(xc)
            else:
                if hasattr(xc, "to"):
                    xc = xc.to(self.device)
                c = self.get_learned_conditioning(xc)
        else:
            if self.cond_stage_key in ["class_label", "cls"]:
                xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
                return self.get_learned_conditioning(xc)
            else:
                raise NotImplementedError("todo")
        if isinstance(c, list):  # in case the encoder gives us a list
            for i in range(len(c)):
                c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
        else:
            c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
        return c

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                           return_first_stage_outputs=True,
                                           force_c_encode=True,
                                           return_original_cond=True,
                                           bs=N)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', "cls"]:
                try:
                    xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                    log['conditioning'] = xc
                except KeyError:
                    # probably no "human_label" in batch
                    pass
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

        if unconditional_guidance_scale > 1.0:
            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if self.model.conditioning_key == "crossattn-adm":
                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        if inpaint:
            # make a simple center square
            h, w = z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
            mask = mask[:, None, ...]
            with ema_scope("Plotting Inpaint"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint
            mask = 1. - mask
            with ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
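`DiffusionWrapper.forward` is essentially a routing table: `concat` conditioning is concatenated with the latent along channels, `crossattn` conditioning is passed as attention `context`, and ADM-style class conditioning is passed as `y`. A condensed pure-Python sketch of that dispatch, with a hypothetical stub in place of the UNet (names below are illustrative only):

```python
def fake_unet(x, t, context=None, y=None):
    # hypothetical stand-in for the diffusion UNet; records how conditioning arrived
    return {'x': x, 't': t, 'context': context, 'y': y}

def route(x, t, conditioning_key, c_concat=None, c_crossattn=None):
    # condensed sketch of DiffusionWrapper.forward's dispatch
    if conditioning_key is None:
        return fake_unet(x, t)
    if conditioning_key == 'concat':
        return fake_unet(x + c_concat, t)  # stands in for torch.cat along channels
    if conditioning_key == 'crossattn':
        return fake_unet(x, t, context=c_crossattn)
    if conditioning_key == 'adm':
        return fake_unet(x, t, y=c_crossattn[0])  # class embedding passed as y
    raise NotImplementedError(conditioning_key)

assert route([1], 0, 'concat', c_concat=[2])['x'] == [1, 2]
assert route([1], 0, 'crossattn', c_crossattn='emb')['context'] == 'emb'
assert route([1], 0, 'adm', c_crossattn=['cls'])['y'] == 'cls'
```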


class LatentUpscaleDiffusion(LatentDiffusion):
    def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
        super().__init__(*args, **kwargs)
        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
        assert not self.cond_stage_trainable
        self.instantiate_low_stage(low_scale_config)
        self.low_scale_key = low_scale_key
        self.noise_level_key = noise_level_key

    def instantiate_low_stage(self, config):
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        if not log_mode:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
        else:
            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                                  force_c_encode=True, return_original_cond=True, bs=bs)
        x_low = batch[self.low_scale_key][:bs]
        x_low = rearrange(x_low, 'b h w c -> b c h w')
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        zx, noise_level = self.low_scale_model(x_low)
        if self.noise_level_key is not None:
            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
            raise NotImplementedError('TODO')

        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
        if log_mode:
            # TODO: maybe disable if too expensive
            x_low_rec = self.low_scale_model.decode(zx)
            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
        return z, all_conds

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
                   **kwargs):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
                                                                          log_mode=True)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["x_lr"] = x_low
        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', 'cls']:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            # TODO explore better "unconditional" choices for the other keys
            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
            uc = dict()
            for k in c:
                if k == "c_crossattn":
                    assert isinstance(c[k], list) and len(c[k]) == 1
                    uc[k] = [uc_tmp]
                elif k == "c_adm":  # todo: only run with text-based guidance?
                    assert isinstance(c[k], torch.Tensor)
                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                    uc[k] = c[k]
                elif isinstance(c[k], list):
                    uc[k] = list(c[k])
                else:
                    uc[k] = c[k]

            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        return log

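# Aside (not part of this module): the classifier-free-guidance branch above passes
# `unconditional_guidance_scale` and an unconditional conditioning dict to `sample_log`.
# Downstream, the sampler blends the two noise predictions as
# eps_hat = eps_uncond + scale * (eps_cond - eps_uncond). A minimal standalone sketch,
# with `cfg_combine` as a hypothetical helper operating on per-element predictions:

```python
def cfg_combine(eps_cond, eps_uncond, scale):
    """Blend conditional/unconditional noise predictions with guidance weight `scale`.

    scale == 1.0 recovers the purely conditional prediction; scale > 1.0
    extrapolates along the conditional direction, strengthening guidance.
    """
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]
```

# With identical predictions the scale has no effect, which is why the
# `unconditional_guidance_scale > 1.0` check above skips the extra pass otherwise.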

class LatentFinetuneDiffusion(LatentDiffusion):
    """
    Basis for different finetunes, such as inpainting or depth2image.
    To disable finetuning mode, set finetune_keys to None.
    """

    def __init__(self,
                 concat_keys: tuple,
                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
                                "model_ema.diffusion_modelinput_blocks00weight"  # EMA params are stored with dots stripped from names
                                ),
                 keep_finetune_dims=4,
                 # if model was trained without concat mode before and we would like to keep these channels
                 c_concat_log_start=None,  # to log reconstruction of c_concat codes
                 c_concat_log_end=None,
                 *args, **kwargs
                 ):
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", list())
        super().__init__(*args, **kwargs)
        self.finetune_keys = finetune_keys
        self.concat_keys = concat_keys
        self.keep_dims = keep_finetune_dims
        self.c_concat_log_start = c_concat_log_start
        self.c_concat_log_end = c_concat_log_end
        if exists(self.finetune_keys):
            assert exists(ckpt_path), 'can only finetune from a given checkpoint'
        if exists(ckpt_path):
            self.init_from_ckpt(ckpt_path, ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]

            # make it explicit, finetune by including extra input channels
            if exists(self.finetune_keys) and k in self.finetune_keys:
                new_entry = None
                for name, param in self.named_parameters():
                    if name in self.finetune_keys:
                        print(
                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
                        new_entry = torch.zeros_like(param)  # zero init
                assert exists(new_entry), 'did not find matching parameter to modify'
                new_entry[:, :self.keep_dims, ...] = sd[k]
                sd[k] = new_entry

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', 'cls']:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if self.c_concat_log_start is not None or self.c_concat_log_end is not None:
            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                         batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            uc_cat = c_cat
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

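# Aside (not part of this module): `init_from_ckpt` above widens the first conv's
# input dimension by zero-initializing the new concat channels, so the finetuned
# model initially behaves like the pretrained one and learns to use the extra
# channels gradually. A minimal standalone sketch of that idea on a simplified
# 2-D weight (`expand_input_weight` is a hypothetical helper, not repo API):

```python
def expand_input_weight(old_rows, new_in_channels):
    """Pad each output filter's input dimension with zeros.

    `old_rows` is a list of per-output-channel weight rows over the original
    input channels; the returned rows keep the pretrained weights and append
    zeros for the `new_in_channels - len(row)` added concat channels.
    """
    expanded = []
    for row in old_rows:
        expanded.append(list(row) + [0.0] * (new_in_channels - len(row)))
    return expanded
```

# The real code does the mirror operation on a 4-D conv weight: it zero-inits a
# tensor shaped like the new parameter and copies the checkpoint weights into
# the first `keep_finetune_dims` input channels.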

class LatentInpaintDiffusion(LatentFinetuneDiffusion):
    """
    Can either run as a pure inpainting model (only concat mode) or with mixed conditionings,
    e.g. mask as concat and text via cross-attn.
    To disable finetuning mode, set finetune_keys to None.
    """

    def __init__(self,
                 concat_keys=("mask", "masked_image"),
                 masked_image_key="masked_image",
                 *args, **kwargs
                 ):
        super().__init__(concat_keys, *args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
SYMBOL INDEX (1216 symbols across 67 files)

FILE: ToonCrafter/cldm/cldm.py
  class ControlledUnetModel (line 23) | class ControlledUnetModel(UNetModel):
    method forward (line 24) | def forward(self, x, timesteps, context=None, features_adapter=None, f...
  class ControlNet (line 90) | class ControlNet(nn.Module):
    method __init__ (line 91) | def __init__(
    method make_zero_conv (line 323) | def make_zero_conv(self, channels):
    method forward (line 326) | def forward(self, x, hint, timesteps, context, **kwargs):
  class ControlLDM (line 351) | class ControlLDM(LatentDiffusion):
    method __init__ (line 353) | def __init__(self, control_stage_config, control_key, only_mid_control...
    method get_input (line 361) | def get_input(self, batch, k, bs=None, *args, **kwargs):
    method apply_model (line 371) | def apply_model(self, x_noisy, t, cond, *args, **kwargs):
    method get_unconditional_conditioning (line 387) | def get_unconditional_conditioning(self, N):
    method log_images (line 391) | def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50,...
    method sample_log (line 452) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    method configure_optimizers (line 459) | def configure_optimizers(self):
    method low_vram_shift (line 468) | def low_vram_shift(self, is_diffusing):

FILE: ToonCrafter/cldm/ddim_hacked.py
  class DDIMSampler (line 10) | class DDIMSampler(object):
    method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
    method register_buffer (line 17) | def register_buffer(self, name, attr):
    method make_schedule (line 23) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
    method sample (line 55) | def sample(self,
    method ddim_sampling (line 123) | def ddim_sampling(self, cond, shape,
    method p_sample_ddim (line 181) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
    method encode (line 234) | def encode(self, x0, c, t_enc, use_original_steps=False, return_interm...
    method stochastic_encode (line 282) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    method decode (line 298) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...

FILE: ToonCrafter/cldm/hack.py
  function disable_verbosity (line 11) | def disable_verbosity():
  function enable_sliced_attention (line 17) | def enable_sliced_attention():
  function hack_everything (line 23) | def hack_everything(clip_skip=0):
  function _hacked_clip_forward (line 32) | def _hacked_clip_forward(self, text):
  function _hacked_sliced_attentin_forward (line 72) | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):

FILE: ToonCrafter/cldm/logger.py
  class ImageLogger (line 11) | class ImageLogger(Callback):
    method __init__ (line 12) | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, inc...
    method log_local (line 28) | def log_local(self, save_dir, split, images, global_step, current_epoc...
    method log_img (line 42) | def log_img(self, pl_module, batch, batch_idx, split="train"):
    method check_frequency (line 71) | def check_frequency(self, check_idx):
    method on_train_batch_end (line 74) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...

FILE: ToonCrafter/cldm/model.py
  function get_state_dict (line 8) | def get_state_dict(d):
  function load_state_dict (line 12) | def load_state_dict(ckpt_path, location='cpu'):
  function create_model (line 24) | def create_model(config_path):

FILE: ToonCrafter/gradio_app.py
  function dynamicrafter_demo (line 16) | def dynamicrafter_demo(result_dir='./tmp/', res=512):
  function get_parser (line 67) | def get_parser():

FILE: ToonCrafter/ldm/data/util.py
  class AddMiDaS (line 6) | class AddMiDaS(object):
    method __init__ (line 7) | def __init__(self, model_type):
    method pt2np (line 11) | def pt2np(self, x):
    method np2pt (line 15) | def np2pt(self, x):
    method __call__ (line 19) | def __call__(self, sample):

FILE: ToonCrafter/ldm/models/autoencoder.py
  class AutoencoderKL (line 13) | class AutoencoderKL(pl.LightningModule):
    method __init__ (line 14) | def __init__(self,
    method init_from_ckpt (line 52) | def init_from_ckpt(self, path, ignore_keys=list()):
    method ema_scope (line 64) | def ema_scope(self, context=None):
    method on_train_batch_end (line 78) | def on_train_batch_end(self, *args, **kwargs):
    method encode (line 82) | def encode(self, x):
    method decode (line 88) | def decode(self, z):
    method forward (line 93) | def forward(self, input, sample_posterior=True):
    method get_input (line 102) | def get_input(self, batch, k):
    method training_step (line 109) | def training_step(self, batch, batch_idx, optimizer_idx):
    method validation_step (line 130) | def validation_step(self, batch, batch_idx):
    method _validation_step (line 136) | def _validation_step(self, batch, batch_idx, postfix=""):
    method configure_optimizers (line 150) | def configure_optimizers(self):
    method get_last_layer (line 163) | def get_last_layer(self):
    method log_images (line 167) | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
    method to_rgb (line 192) | def to_rgb(self, x):
  class IdentityFirstStage (line 201) | class IdentityFirstStage(torch.nn.Module):
    method __init__ (line 202) | def __init__(self, *args, vq_interface=False, **kwargs):
    method encode (line 206) | def encode(self, x, *args, **kwargs):
    method decode (line 209) | def decode(self, x, *args, **kwargs):
    method quantize (line 212) | def quantize(self, x, *args, **kwargs):
    method forward (line 217) | def forward(self, x, *args, **kwargs):

FILE: ToonCrafter/ldm/models/diffusion/ddim.py
  class DDIMSampler (line 10) | class DDIMSampler(object):
    method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
    method register_buffer (line 17) | def register_buffer(self, name, attr):
    method make_schedule (line 23) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
    method sample (line 55) | def sample(self,
    method ddim_sampling (line 123) | def ddim_sampling(self, cond, shape,
    method p_sample_ddim (line 181) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
    method encode (line 254) | def encode(self, x0, c, t_enc, use_original_steps=False, return_interm...
    method stochastic_encode (line 301) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    method decode (line 317) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...

FILE: ToonCrafter/ldm/models/diffusion/ddpm.py
  function disabled_train (line 36) | def disabled_train(self, mode=True):
  function uniform_on_device (line 42) | def uniform_on_device(r1, r2, shape, device):
  class DDPM (line 46) | class DDPM(pl.LightningModule):
    method __init__ (line 48) | def __init__(self,
    method register_schedule (line 138) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
    method ema_scope (line 195) | def ema_scope(self, context=None):
    method init_from_ckpt (line 210) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    method q_mean_variance (line 272) | def q_mean_variance(self, x_start, t):
    method predict_start_from_noise (line 284) | def predict_start_from_noise(self, x_t, t, noise):
    method predict_start_from_z_and_v (line 290) | def predict_start_from_z_and_v(self, x_t, t, v):
    method predict_eps_from_z_and_v (line 298) | def predict_eps_from_z_and_v(self, x_t, t, v):
    method q_posterior (line 304) | def q_posterior(self, x_start, x_t, t):
    method p_mean_variance (line 313) | def p_mean_variance(self, x, t, clip_denoised: bool):
    method p_sample (line 326) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    method p_sample_loop (line 335) | def p_sample_loop(self, shape, return_intermediates=False):
    method sample (line 350) | def sample(self, batch_size=16, return_intermediates=False):
    method q_sample (line 356) | def q_sample(self, x_start, t, noise=None):
    method get_v (line 361) | def get_v(self, x, noise, t):
    method get_loss (line 367) | def get_loss(self, pred, target, mean=True):
    method p_losses (line 382) | def p_losses(self, x_start, t, noise=None):
    method forward (line 413) | def forward(self, x, *args, **kwargs):
    method get_input (line 419) | def get_input(self, batch, k):
    method shared_step (line 427) | def shared_step(self, batch):
    method training_step (line 432) | def training_step(self, batch, batch_idx):
    method validation_step (line 457) | def validation_step(self, batch, batch_idx):
    method on_train_batch_end (line 465) | def on_train_batch_end(self, *args, **kwargs):
    method _get_rows_from_list (line 469) | def _get_rows_from_list(self, samples):
    method log_images (line 477) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
    method configure_optimizers (line 514) | def configure_optimizers(self):
  class LatentDiffusion (line 523) | class LatentDiffusion(DDPM):
    method __init__ (line 526) | def __init__(self,
    method make_cond_schedule (line 584) | def make_cond_schedule(self, ):
    method on_train_batch_start (line 591) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    method register_schedule (line 606) | def register_schedule(self,
    method instantiate_first_stage (line 615) | def instantiate_first_stage(self, config):
    method instantiate_cond_stage (line 622) | def instantiate_cond_stage(self, config):
    method _get_denoise_row_from_list (line 643) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
    method get_first_stage_encoding (line 655) | def get_first_stage_encoding(self, encoder_posterior):
    method get_learned_conditioning (line 664) | def get_learned_conditioning(self, c):
    method meshgrid (line 677) | def meshgrid(self, h, w):
    method delta_border (line 684) | def delta_border(self, h, w):
    method get_weighting (line 698) | def get_weighting(self, h, w, Ly, Lx, device):
    method get_fold_unfold (line 714) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo...
    method get_input (line 767) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
    method decode_first_stage (line 820) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
    method encode_first_stage (line 831) | def encode_first_stage(self, x):
    method shared_step (line 834) | def shared_step(self, batch, **kwargs):
    method forward (line 839) | def forward(self, x, c, *args, **kwargs):
    method apply_model (line 850) | def apply_model(self, x_noisy, t, cond, return_ids=False):
    method _predict_eps_from_xstart (line 867) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    method _prior_bpd (line 871) | def _prior_bpd(self, x_start):
    method p_losses (line 885) | def p_losses(self, x_start, cond, t, noise=None):
    method p_mean_variance (line 922) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
    method p_sample (line 954) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
    method progressive_denoising (line 985) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
    method p_sample_loop (line 1041) | def p_sample_loop(self, cond, shape, return_intermediates=False,
    method sample (line 1092) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
    method sample_log (line 1110) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    method get_unconditional_conditioning (line 1124) | def get_unconditional_conditioning(self, batch_size, null_label=None):
    method log_images (line 1149) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ...
    method configure_optimizers (line 1278) | def configure_optimizers(self):
    method to_rgb (line 1303) | def to_rgb(self, x):
  class DiffusionWrapper (line 1312) | class DiffusionWrapper(pl.LightningModule):
    method __init__ (line 1313) | def __init__(self, diff_model_config, conditioning_key):
    method forward (line 1320) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = Non...
  class LatentUpscaleDiffusion (line 1354) | class LatentUpscaleDiffusion(LatentDiffusion):
    method __init__ (line 1355) | def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_...
    method instantiate_low_stage (line 1363) | def instantiate_low_stage(self, config):
    method get_input (line 1371) | def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
    method log_images (line 1393) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
  class LatentFinetuneDiffusion (line 1492) | class LatentFinetuneDiffusion(LatentDiffusion):
    method __init__ (line 1498) | def __init__(self,
    method init_from_ckpt (line 1521) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    method log_images (line 1553) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
  class LatentInpaintDiffusion (line 1634) | class LatentInpaintDiffusion(LatentFinetuneDiffusion):
    method __init__ (line 1641) | def __init__(self,
    method get_input (line 1651) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
    method log_images (line 1677) | def log_images(self, *args, **kwargs):
  class LatentDepth2ImageDiffusion (line 1684) | class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
    method __init__ (line 1689) | def __init__(self, depth_stage_config, concat_keys=("midas_in",), *arg...
    method get_input (line 1695) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
    method log_images (line 1728) | def log_images(self, *args, **kwargs):
  class LatentUpscaleFinetuneDiffusion (line 1737) | class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
    method __init__ (line 1741) | def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
    method instantiate_low_stage (line 1752) | def instantiate_low_stage(self, config):
    method get_input (line 1760) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
    method log_images (line 1794) | def log_images(self, *args, **kwargs):

FILE: ToonCrafter/ldm/models/diffusion/dpm_solver/dpm_solver.py
  class NoiseScheduleVP (line 7) | class NoiseScheduleVP:
    method __init__ (line 8) | def __init__(
    method marginal_log_mean_coeff (line 106) | def marginal_log_mean_coeff(self, t):
    method marginal_alpha (line 120) | def marginal_alpha(self, t):
    method marginal_std (line 126) | def marginal_std(self, t):
    method marginal_lambda (line 132) | def marginal_lambda(self, t):
    method inverse_lambda (line 140) | def inverse_lambda(self, lamb):
  function model_wrapper (line 161) | def model_wrapper(
  class DPM_Solver (line 319) | class DPM_Solver:
    method __init__ (line 320) | def __init__(self, model_fn, noise_schedule, predict_x0=False, thresho...
    method noise_prediction_fn (line 346) | def noise_prediction_fn(self, x, t):
    method data_prediction_fn (line 352) | def data_prediction_fn(self, x, t):
    method model_fn (line 367) | def model_fn(self, x, t):
    method get_time_steps (line 376) | def get_time_steps(self, skip_type, t_T, t_0, N, device):
    method get_orders_and_timesteps_for_singlestep_solver (line 405) | def get_orders_and_timesteps_for_singlestep_solver(self, steps, order,...
    method denoise_to_zero_fn (line 463) | def denoise_to_zero_fn(self, x, s):
    method dpm_solver_first_update (line 469) | def dpm_solver_first_update(self, x, s, t, model_s=None, return_interm...
    method singlestep_dpm_solver_second_update (line 515) | def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s...
    method singlestep_dpm_solver_third_update (line 599) | def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2...
    method multistep_dpm_solver_second_update (line 723) | def multistep_dpm_solver_second_update(self, x, model_prev_list, t_pre...
    method multistep_dpm_solver_third_update (line 780) | def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev...
    method singlestep_dpm_solver_update (line 827) | def singlestep_dpm_solver_update(self, x, s, t, order, return_intermed...
    method multistep_dpm_solver_update (line 855) | def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list,...
    method dpm_solver_adaptive (line 878) | def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0....
    method sample (line 939) | def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_...
  function interpolate_fn (line 1104) | def interpolate_fn(x, xp, yp):
  function expand_dims (line 1145) | def expand_dims(v, dims):
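The two module-level helpers at the bottom of `dpm_solver.py` are small tensor utilities; `expand_dims` in particular just right-pads singleton axes so a per-batch scalar (an alpha/sigma coefficient, say) can broadcast against an image batch. A minimal sketch of that idea, using NumPy as a stand-in for the torch original (signature taken from the listing, everything else assumed):

```python
import numpy as np

def expand_dims(v, dims):
    # Expand a 1-D array of shape (N,) to (N, 1, ..., 1) with `dims`
    # total dimensions, so it broadcasts against an (N, C, H, W) batch.
    return v[(...,) + (None,) * (dims - 1)]

coeff = np.array([0.5, 1.0, 2.0])        # one scalar per batch element
x = np.ones((3, 4, 8, 8))                # (N, C, H, W) batch
scaled = expand_dims(coeff, x.ndim) * x  # broadcasts to (3, 4, 8, 8)
```

Without the padding, `coeff * x` would fail: NumPy/torch broadcasting aligns trailing axes, so the `(3,)` vector must become `(3, 1, 1, 1)` first.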

FILE: ToonCrafter/ldm/models/diffusion/dpm_solver/sampler.py
  class DPMSolverSampler (line 13) | class DPMSolverSampler(object):
    method __init__ (line 14) | def __init__(self, model, **kwargs):
    method register_buffer (line 20) | def register_buffer(self, name, attr):
    method sample (line 27) | def sample(self,

FILE: ToonCrafter/ldm/models/diffusion/plms.py
  class PLMSSampler (line 12) | class PLMSSampler(object):
    method __init__ (line 13) | def __init__(self, model, schedule="linear", **kwargs):
    method register_buffer (line 19) | def register_buffer(self, name, attr):
    method make_schedule (line 25) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
    method sample (line 59) | def sample(self,
    method plms_sampling (line 118) | def plms_sampling(self, cond, shape,
    method p_sample_plms (line 178) | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_origin...

FILE: ToonCrafter/ldm/models/diffusion/sampling_util.py
  function append_dims (line 5) | def append_dims(x, target_dims):
  function norm_thresholding (line 14) | def norm_thresholding(x0, value):
  function spatial_norm_thresholding (line 19) | def spatial_norm_thresholding(x0, value):

FILE: ToonCrafter/ldm/modules/attention.py
  function exists (line 23) | def exists(val):
  function uniq (line 27) | def uniq(arr):
  function default (line 31) | def default(val, d):
  function max_neg_value (line 37) | def max_neg_value(t):
  function init_ (line 41) | def init_(tensor):
  class GEGLU (line 49) | class GEGLU(nn.Module):
    method __init__ (line 50) | def __init__(self, dim_in, dim_out):
    method forward (line 54) | def forward(self, x):
  class FeedForward (line 59) | class FeedForward(nn.Module):
    method __init__ (line 60) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
    method forward (line 75) | def forward(self, x):
  function zero_module (line 79) | def zero_module(module):
  function Normalize (line 88) | def Normalize(in_channels):
  class SpatialSelfAttention (line 92) | class SpatialSelfAttention(nn.Module):
    method __init__ (line 93) | def __init__(self, in_channels):
    method forward (line 119) | def forward(self, x):
  class CrossAttention (line 145) | class CrossAttention(nn.Module):
    method __init__ (line 146) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 163) | def forward(self, x, context=None, mask=None):
  class MemoryEfficientCrossAttention (line 197) | class MemoryEfficientCrossAttention(nn.Module):
    method __init__ (line 199) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 216) | def forward(self, x, context=None, mask=None):
  class BasicTransformerBlock (line 246) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 251) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
    method forward (line 268) | def forward(self, x, context=None):
    method _forward (line 271) | def _forward(self, x, context=None):
  class SpatialTransformer (line 278) | class SpatialTransformer(nn.Module):
    method __init__ (line 287) | def __init__(self, in_channels, n_heads, d_head,
    method forward (line 321) | def forward(self, x, context=None):

FILE: ToonCrafter/ldm/modules/diffusionmodules/model.py
  function get_timestep_embedding (line 20) | def get_timestep_embedding(timesteps, embedding_dim):
  function nonlinearity (line 41) | def nonlinearity(x):
  function Normalize (line 46) | def Normalize(in_channels, num_groups=32):
  class Upsample (line 50) | class Upsample(nn.Module):
    method __init__ (line 51) | def __init__(self, in_channels, with_conv):
    method forward (line 61) | def forward(self, x):
  class Downsample (line 68) | class Downsample(nn.Module):
    method __init__ (line 69) | def __init__(self, in_channels, with_conv):
    method forward (line 80) | def forward(self, x):
  class ResnetBlock (line 90) | class ResnetBlock(nn.Module):
    method __init__ (line 91) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
    method forward (line 129) | def forward(self, x, temb):
  class AttnBlock (line 152) | class AttnBlock(nn.Module):
    method __init__ (line 153) | def __init__(self, in_channels):
    method forward (line 179) | def forward(self, x):
  class MemoryEfficientAttnBlock (line 205) | class MemoryEfficientAttnBlock(nn.Module):
    method __init__ (line 212) | def __init__(self, in_channels):
    method forward (line 239) | def forward(self, x):
  class MemoryEfficientCrossAttentionWrapper (line 271) | class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    method forward (line 272) | def forward(self, x, context=None, mask=None):
  function make_attn (line 280) | def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
  class Model (line 300) | class Model(nn.Module):
    method __init__ (line 301) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 400) | def forward(self, x, t=None, context=None):
    method get_last_layer (line 448) | def get_last_layer(self):
  class Encoder (line 452) | class Encoder(nn.Module):
    method __init__ (line 453) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 518) | def forward(self, x):
  class Decoder (line 546) | class Decoder(nn.Module):
    method __init__ (line 547) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 619) | def forward(self, z):
  class SimpleDecoder (line 655) | class SimpleDecoder(nn.Module):
    method __init__ (line 656) | def __init__(self, in_channels, out_channels, *args, **kwargs):
    method forward (line 678) | def forward(self, x):
  class UpsampleDecoder (line 691) | class UpsampleDecoder(nn.Module):
    method __init__ (line 692) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
    method forward (line 725) | def forward(self, x):
  class LatentRescaler (line 739) | class LatentRescaler(nn.Module):
    method __init__ (line 740) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
    method forward (line 764) | def forward(self, x):
  class MergedRescaleEncoder (line 776) | class MergedRescaleEncoder(nn.Module):
    method __init__ (line 777) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
    method forward (line 789) | def forward(self, x):
  class MergedRescaleDecoder (line 795) | class MergedRescaleDecoder(nn.Module):
    method __init__ (line 796) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
    method forward (line 806) | def forward(self, x):
  class Upsampler (line 812) | class Upsampler(nn.Module):
    method __init__ (line 813) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
    method forward (line 825) | def forward(self, x):
  class Resize (line 831) | class Resize(nn.Module):
    method __init__ (line 832) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
    method forward (line 847) | def forward(self, x, scale_factor=1.0):

FILE: ToonCrafter/ldm/modules/diffusionmodules/openaimodel.py
  function convert_module_to_f16 (line 23) | def convert_module_to_f16(x):
  function convert_module_to_f32 (line 26) | def convert_module_to_f32(x):
  class AttentionPool2d (line 31) | class AttentionPool2d(nn.Module):
    method __init__ (line 36) | def __init__(
    method forward (line 50) | def forward(self, x):
  class TimestepBlock (line 61) | class TimestepBlock(nn.Module):
    method forward (line 67) | def forward(self, x, emb):
  class TimestepEmbedSequential (line 73) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    method forward (line 79) | def forward(self, x, emb, context=None, check=False):
  class Upsample (line 90) | class Upsample(nn.Module):
    method __init__ (line 99) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
    method forward (line 108) | def forward(self, x):
  class TransposedUpsample (line 120) | class TransposedUpsample(nn.Module):
    method __init__ (line 122) | def __init__(self, channels, out_channels=None, ks=5):
    method forward (line 129) | def forward(self,x):
  class Downsample (line 133) | class Downsample(nn.Module):
    method __init__ (line 142) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
    method forward (line 157) | def forward(self, x):
  class ResBlock (line 162) | class ResBlock(TimestepBlock):
    method __init__ (line 178) | def __init__(
    method forward (line 242) | def forward(self, x, emb):
    method _forward (line 254) | def _forward(self, x, emb):
  class AttentionBlock (line 277) | class AttentionBlock(nn.Module):
    method __init__ (line 284) | def __init__(
    method forward (line 313) | def forward(self, x):
    method _forward (line 317) | def _forward(self, x):
  function count_flops_attn (line 326) | def count_flops_attn(model, _x, y):
  class QKVAttentionLegacy (line 346) | class QKVAttentionLegacy(nn.Module):
    method __init__ (line 351) | def __init__(self, n_heads):
    method forward (line 355) | def forward(self, qkv):
    method count_flops (line 374) | def count_flops(model, _x, y):
  class QKVAttention (line 378) | class QKVAttention(nn.Module):
    method __init__ (line 383) | def __init__(self, n_heads):
    method forward (line 387) | def forward(self, qkv):
    method count_flops (line 408) | def count_flops(model, _x, y):
  class UNetModel (line 412) | class UNetModel(nn.Module):
    method __init__ (line 442) | def __init__(
    method convert_to_fp16 (line 738) | def convert_to_fp16(self):
    method convert_to_fp32 (line 746) | def convert_to_fp32(self):
    method forward (line 754) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):

FILE: ToonCrafter/ldm/modules/diffusionmodules/upscaling.py
  class AbstractLowScaleModel (line 10) | class AbstractLowScaleModel(nn.Module):
    method __init__ (line 12) | def __init__(self, noise_schedule_config=None):
    method register_schedule (line 17) | def register_schedule(self, beta_schedule="linear", timesteps=1000,
    method q_sample (line 44) | def q_sample(self, x_start, t, noise=None):
    method forward (line 49) | def forward(self, x):
    method decode (line 52) | def decode(self, x):
  class SimpleImageConcat (line 56) | class SimpleImageConcat(AbstractLowScaleModel):
    method __init__ (line 58) | def __init__(self):
    method forward (line 62) | def forward(self, x):
  class ImageConcatWithNoiseAugmentation (line 67) | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    method __init__ (line 68) | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cud...
    method forward (line 72) | def forward(self, x, noise_level=None):

FILE: ToonCrafter/ldm/modules/diffusionmodules/util.py
  function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
  function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
  function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
  function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
  function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
  function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
  class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
    method forward (line 121) | def forward(ctx, run_function, length, *args):
    method backward (line 133) | def backward(ctx, *output_grads):
  function timestep_embedding (line 154) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  function zero_module (line 177) | def zero_module(module):
  function scale_module (line 186) | def scale_module(module, scale):
  function mean_flat (line 195) | def mean_flat(tensor):
  function normalization (line 202) | def normalization(channels):
  class SiLU (line 212) | class SiLU(nn.Module):
    method forward (line 213) | def forward(self, x):
  class GroupNorm32 (line 217) | class GroupNorm32(nn.GroupNorm):
    method forward (line 218) | def forward(self, x):
  function conv_nd (line 221) | def conv_nd(dims, *args, **kwargs):
  function linear (line 234) | def linear(*args, **kwargs):
  function avg_pool_nd (line 241) | def avg_pool_nd(dims, *args, **kwargs):
  class HybridConditioner (line 254) | class HybridConditioner(nn.Module):
    method __init__ (line 256) | def __init__(self, c_concat_config, c_crossattn_config):
    method forward (line 261) | def forward(self, c_concat, c_crossattn):
  function noise_like (line 267) | def noise_like(shape, device, repeat=False):
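Most of `diffusionmodules/util.py` is shared plumbing; the most algorithmic piece is `timestep_embedding`, the standard sinusoidal encoding from the Transformer/DDPM literature. A NumPy sketch of that encoding (the repo's version is torch-based and also supports `repeat_only`; details beyond the listed signature are assumptions):

```python
import numpy as np

def timestep_embedding(timesteps, dim, max_period=10000):
    # Half the channels carry cosines, half carry sines, at
    # geometrically spaced frequencies from 1 down to ~1/max_period.
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # odd dim: pad a zero channel
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb

emb = timestep_embedding(np.arange(4), 8)  # shape (4, 8)
```

At `t = 0` every cosine channel is 1 and every sine channel is 0, which is a quick sanity check for any reimplementation.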

FILE: ToonCrafter/ldm/modules/distributions/distributions.py
  class AbstractDistribution (line 5) | class AbstractDistribution:
    method sample (line 6) | def sample(self):
    method mode (line 9) | def mode(self):
  class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
    method __init__ (line 14) | def __init__(self, value):
    method sample (line 17) | def sample(self):
    method mode (line 20) | def mode(self):
  class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
    method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
    method sample (line 35) | def sample(self):
    method kl (line 39) | def kl(self, other=None):
    method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
    method mode (line 61) | def mode(self):
  function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
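`DiagonalGaussianDistribution` is the VAE posterior used by the autoencoder: `parameters` is a channel-concatenated (mean, logvar) pair, `sample()` is the reparameterization trick, and `kl()` against the standard normal has the usual closed form. A hedged NumPy sketch of just the KL term (axis handling simplified relative to the torch original):

```python
import numpy as np

def diagonal_gaussian_kl(mean, logvar):
    # KL( N(mean, diag(exp(logvar))) || N(0, I) ), summed over all
    # non-batch axes: 0.5 * sum(mean^2 + var - 1 - logvar).
    return 0.5 * np.sum(mean**2 + np.exp(logvar) - 1.0 - logvar,
                        axis=tuple(range(1, mean.ndim)))

kl = diagonal_gaussian_kl(np.zeros((2, 4)), np.zeros((2, 4)))
```

A zero-mean, unit-variance posterior gives KL of exactly 0, and shifting the mean by 1 in a single dimension adds 0.5, which makes the formula easy to unit-test.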

FILE: ToonCrafter/ldm/modules/ema.py
  class LitEma (line 5) | class LitEma(nn.Module):
    method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
    method reset_num_updates (line 25) | def reset_num_updates(self):
    method forward (line 29) | def forward(self, model):
    method copy_to (line 50) | def copy_to(self, model):
    method store (line 59) | def store(self, parameters):
    method restore (line 68) | def restore(self, parameters):
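`LitEma` keeps a shadow copy of every parameter and nudges it toward the live weights on each `forward`; with `use_num_upates` enabled the effective decay warms up as `min(decay, (1 + n) / (10 + n))`, so early updates track the model closely before settling at the nominal decay. A minimal pure-Python sketch of that update rule (operating on plain floats rather than a torch module, so the buffer bookkeeping in the real class is omitted):

```python
def ema_update(shadow, params, decay=0.9999, num_updates=None):
    # Warm-up schedule used by LitEma-style EMAs: for small n the
    # effective decay is ~n/(n+10), then it saturates at `decay`.
    if num_updates is not None:
        decay = min(decay, (1 + num_updates) / (10 + num_updates))
    one_minus = 1.0 - decay
    # shadow <- shadow - (1 - decay) * (shadow - param)
    return [s - one_minus * (s - p) for s, p in zip(shadow, params)]

shadow = ema_update([0.0], [1.0], decay=0.9999, num_updates=0)
```

On the very first update (`num_updates=0`) the effective decay is 0.1, so the shadow jumps 90% of the way to the live weight; after thousands of steps it moves only 0.01% per step.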

FILE: ToonCrafter/ldm/modules/encoders/modules.py
  class AbstractEncoder (line 11) | class AbstractEncoder(nn.Module):
    method __init__ (line 12) | def __init__(self):
    method encode (line 15) | def encode(self, *args, **kwargs):
  class IdentityEncoder (line 19) | class IdentityEncoder(AbstractEncoder):
    method encode (line 21) | def encode(self, x):
  class ClassEmbedder (line 25) | class ClassEmbedder(nn.Module):
    method __init__ (line 26) | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
    method forward (line 33) | def forward(self, batch, key=None, disable_dropout=False):
    method get_unconditional_conditioning (line 45) | def get_unconditional_conditioning(self, bs, device="cuda"):
  function disabled_train (line 52) | def disabled_train(self, mode=True):
  class FrozenT5Embedder (line 58) | class FrozenT5Embedder(AbstractEncoder):
    method __init__ (line 60) | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_...
    method freeze (line 69) | def freeze(self):
    method forward (line 75) | def forward(self, text):
    method encode (line 84) | def encode(self, text):
  class FrozenCLIPEmbedder (line 88) | class FrozenCLIPEmbedder(AbstractEncoder):
    method __init__ (line 95) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
    method freeze (line 111) | def freeze(self):
    method forward (line 117) | def forward(self, text):
    method encode (line 130) | def encode(self, text):
  class FrozenOpenCLIPEmbedder (line 134) | class FrozenOpenCLIPEmbedder(AbstractEncoder):
    method __init__ (line 143) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
    method freeze (line 163) | def freeze(self):
    method forward (line 168) | def forward(self, text):
    method encode_with_transformer (line 173) | def encode_with_transformer(self, text):
    method text_transformer_forward (line 182) | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
    method encode (line 192) | def encode(self, text):
  class FrozenCLIPT5Encoder (line 196) | class FrozenCLIPT5Encoder(AbstractEncoder):
    method __init__ (line 197) | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_ve...
    method encode (line 205) | def encode(self, text):
    method forward (line 208) | def forward(self, text):

FILE: ToonCrafter/ldm/modules/image_degradation/bsrgan.py
  function modcrop_np (line 29) | def modcrop_np(img, sf):
  function analytic_kernel (line 49) | def analytic_kernel(k):
  function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
  function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
  function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
  function blur (line 128) | def blur(x, k):
  function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
  function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
  function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
  function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
  function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
  function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
  function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
  function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
  function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
  function add_blur (line 325) | def add_blur(img, sf=4):
  function add_resize (line 339) | def add_resize(img, sf=4):
  function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
  function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
  function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
  function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
  function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
  function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
  function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
  function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...

FILE: ToonCrafter/ldm/modules/image_degradation/bsrgan_light.py
  function modcrop_np (line 28) | def modcrop_np(img, sf):
  function analytic_kernel (line 48) | def analytic_kernel(k):
  function anisotropic_Gaussian (line 64) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
  function gm_blur_kernel (line 85) | def gm_blur_kernel(mean, cov, size=15):
  function shift_pixel (line 98) | def shift_pixel(x, sf, upper_left=True):
  function blur (line 127) | def blur(x, k):
  function gen_kernel (line 144) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
  function fspecial_gaussian (line 186) | def fspecial_gaussian(hsize, sigma):
  function fspecial_laplacian (line 200) | def fspecial_laplacian(alpha):
  function fspecial (line 209) | def fspecial(filter_type, *args, **kwargs):
  function bicubic_degradation (line 227) | def bicubic_degradation(x, sf=3):
  function srmd_degradation (line 239) | def srmd_degradation(x, k, sf=3):
  function dpsr_degradation (line 261) | def dpsr_degradation(x, k, sf=3):
  function classical_degradation (line 283) | def classical_degradation(x, k, sf=3):
  function add_sharpening (line 298) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
  function add_blur (line 324) | def add_blur(img, sf=4):
  function add_resize (line 342) | def add_resize(img, sf=4):
  function add_Gaussian_noise (line 372) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
  function add_speckle_noise (line 389) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
  function add_Poisson_noise (line 407) | def add_Poisson_noise(img):
  function add_JPEG_noise (line 421) | def add_JPEG_noise(img):
  function random_crop (line 430) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
  function degradation_bsrgan (line 441) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
  function degradation_bsrgan_variant (line 533) | def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):

FILE: ToonCrafter/ldm/modules/image_degradation/utils_image.py
  function is_image_file (line 29) | def is_image_file(filename):
  function get_timestamp (line 33) | def get_timestamp():
  function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
  function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
  function get_image_paths (line 67) | def get_image_paths(dataroot):
  function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
  function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
  function imssave (line 112) | def imssave(imgs, img_path):
  function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
  function mkdir (line 153) | def mkdir(path):
  function mkdirs (line 158) | def mkdirs(paths):
  function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
  function imread_uint (line 185) | def imread_uint(path, n_channels=3):
  function imsave (line 203) | def imsave(img, img_path):
  function imwrite (line 209) | def imwrite(img, img_path):
  function read_img (line 220) | def read_img(path):
  function uint2single (line 249) | def uint2single(img):
  function single2uint (line 254) | def single2uint(img):
  function uint162single (line 259) | def uint162single(img):
  function single2uint16 (line 264) | def single2uint16(img):
  function uint2tensor4 (line 275) | def uint2tensor4(img):
  function uint2tensor3 (line 282) | def uint2tensor3(img):
  function tensor2uint (line 289) | def tensor2uint(img):
  function single2tensor3 (line 302) | def single2tensor3(img):
  function single2tensor4 (line 307) | def single2tensor4(img):
  function tensor2single (line 312) | def tensor2single(img):
  function tensor2single3 (line 320) | def tensor2single3(img):
  function single2tensor5 (line 329) | def single2tensor5(img):
  function single32tensor5 (line 333) | def single32tensor5(img):
  function single42tensor4 (line 337) | def single42tensor4(img):
  function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
  function augment_img (line 380) | def augment_img(img, mode=0):
  function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
  function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
  function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
  function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
  function modcrop (line 494) | def modcrop(img_in, scale):
  function shave (line 510) | def shave(img_in, border=0):
  function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
  function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
  function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
  function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
  function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
  function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
  function ssim (line 669) | def ssim(img1, img2):
  function cubic (line 700) | def cubic(x):
  function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
  function imresize (line 766) | def imresize(img, scale, antialiasing=True):
  function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):

FILE: ToonCrafter/ldm/modules/midas/api.py
  function disabled_train (line 22) | def disabled_train(self, mode=True):
  function load_midas_transform (line 28) | def load_midas_transform(model_type):
  function load_model (line 73) | def load_model(model_type):
  class MiDaSInference (line 137) | class MiDaSInference(nn.Module):
    method __init__ (line 150) | def __init__(self, model_type):
    method forward (line 157) | def forward(self, x):

FILE: ToonCrafter/ldm/modules/midas/midas/base_model.py
  class BaseModel (line 4) | class BaseModel(torch.nn.Module):
    method load (line 5) | def load(self, path):

FILE: ToonCrafter/ldm/modules/midas/midas/blocks.py
  function _make_encoder (line 11) | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=F...
  function _make_scratch (line 49) | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
  function _make_pretrained_efficientnet_lite3 (line 78) | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
  function _make_efficientnet_backbone (line 88) | def _make_efficientnet_backbone(effnet):
  function _make_resnet_backbone (line 101) | def _make_resnet_backbone(resnet):
  function _make_pretrained_resnext101_wsl (line 114) | def _make_pretrained_resnext101_wsl(use_pretrained):
  class Interpolate (line 120) | class Interpolate(nn.Module):
    method __init__ (line 124) | def __init__(self, scale_factor, mode, align_corners=False):
    method forward (line 138) | def forward(self, x):
  class ResidualConvUnit (line 155) | class ResidualConvUnit(nn.Module):
    method __init__ (line 159) | def __init__(self, features):
    method forward (line 177) | def forward(self, x):
  class FeatureFusionBlock (line 194) | class FeatureFusionBlock(nn.Module):
    method __init__ (line 198) | def __init__(self, features):
    method forward (line 209) | def forward(self, *xs):
  class ResidualConvUnit_custom (line 231) | class ResidualConvUnit_custom(nn.Module):
    method __init__ (line 235) | def __init__(self, features, activation, bn):
    method forward (line 263) | def forward(self, x):
  class FeatureFusionBlock_custom (line 291) | class FeatureFusionBlock_custom(nn.Module):
    method __init__ (line 295) | def __init__(self, features, activation, deconv=False, bn=False, expan...
    method forward (line 320) | def forward(self, *xs):

FILE: ToonCrafter/ldm/modules/midas/midas/dpt_depth.py
  function _make_fusion_block (line 15) | def _make_fusion_block(features, use_bn):
  class DPT (line 26) | class DPT(BaseModel):
    method __init__ (line 27) | def __init__(
    method forward (line 67) | def forward(self, x):
  class DPTDepthModel (line 88) | class DPTDepthModel(DPT):
    method __init__ (line 89) | def __init__(self, path=None, non_negative=True, **kwargs):
    method forward (line 107) | def forward(self, x):

FILE: ToonCrafter/ldm/modules/midas/midas/midas_net.py
  class MidasNet (line 12) | class MidasNet(BaseModel):
    method __init__ (line 16) | def __init__(self, path=None, features=256, non_negative=True):
    method forward (line 49) | def forward(self, x):

FILE: ToonCrafter/ldm/modules/midas/midas/midas_net_custom.py
  class MidasNet_small (line 12) | class MidasNet_small(BaseModel):
    method __init__ (line 16) | def __init__(self, path=None, features=64, backbone="efficientnet_lite...
    method forward (line 73) | def forward(self, x):
  function fuse_model (line 109) | def fuse_model(m):

FILE: ToonCrafter/ldm/modules/midas/midas/transforms.py
  function apply_min_size (line 6) | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AR...
  class Resize (line 48) | class Resize(object):
    method __init__ (line 52) | def __init__(
    method constrain_to_multiple_of (line 94) | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
    method get_size (line 105) | def get_size(self, width, height):
    method __call__ (line 162) | def __call__(self, sample):
  class NormalizeImage (line 197) | class NormalizeImage(object):
    method __init__ (line 201) | def __init__(self, mean, std):
    method __call__ (line 205) | def __call__(self, sample):
  class PrepareForNet (line 211) | class PrepareForNet(object):
    method __init__ (line 215) | def __init__(self):
    method __call__ (line 218) | def __call__(self, sample):

FILE: ToonCrafter/ldm/modules/midas/midas/vit.py
  class Slice (line 9) | class Slice(nn.Module):
    method __init__ (line 10) | def __init__(self, start_index=1):
    method forward (line 14) | def forward(self, x):
  class AddReadout (line 18) | class AddReadout(nn.Module):
    method __init__ (line 19) | def __init__(self, start_index=1):
    method forward (line 23) | def forward(self, x):
  class ProjectReadout (line 31) | class ProjectReadout(nn.Module):
    method __init__ (line 32) | def __init__(self, in_features, start_index=1):
    method forward (line 38) | def forward(self, x):
  class Transpose (line 45) | class Transpose(nn.Module):
    method __init__ (line 46) | def __init__(self, dim0, dim1):
    method forward (line 51) | def forward(self, x):
  function forward_vit (line 56) | def forward_vit(pretrained, x):
  function _resize_pos_embed (line 100) | def _resize_pos_embed(self, posemb, gs_h, gs_w):
  function forward_flex (line 117) | def forward_flex(self, x):
  function get_activation (line 159) | def get_activation(name):
  function get_readout_oper (line 166) | def get_readout_oper(vit_features, features, use_readout, start_index=1):
  function _make_vit_b16_backbone (line 183) | def _make_vit_b16_backbone(
  function _make_pretrained_vitl16_384 (line 297) | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=...
  function _make_pretrained_vitb16_384 (line 310) | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=...
  function _make_pretrained_deitb16_384 (line 319) | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks...
  function _make_pretrained_deitb16_distil_384 (line 328) | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore"...
  function _make_vit_b_rn50_backbone (line 343) | def _make_vit_b_rn50_backbone(
  function _make_pretrained_vitb_rn50_384 (line 478) | def _make_pretrained_vitb_rn50_384(

FILE: ToonCrafter/ldm/modules/midas/utils.py
  function read_pfm (line 9) | def read_pfm(path):
  function write_pfm (line 58) | def write_pfm(path, image, scale=1):
  function read_image (line 97) | def read_image(path):
  function resize_image (line 116) | def resize_image(img):
  function resize_depth (line 146) | def resize_depth(depth, width, height):
  function write_depth (line 165) | def write_depth(path, depth, bits=1):

FILE: ToonCrafter/ldm/util.py
  function log_txt_as_img (line 11) | def log_txt_as_img(wh, xc, size=10):
  function ismap (line 35) | def ismap(x):
  function isimage (line 41) | def isimage(x):
  function exists (line 47) | def exists(x):
  function default (line 51) | def default(val, d):
  function mean_flat (line 57) | def mean_flat(tensor):
  function count_params (line 65) | def count_params(model, verbose=False):
  function instantiate_from_config (line 72) | def instantiate_from_config(config):
  function get_obj_from_str (line 82) | def get_obj_from_str(string, reload=False):
  class AdamWwithEMAandWings (line 90) | class AdamWwithEMAandWings(optim.Optimizer):
    method __init__ (line 92) | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  #...
    method __setstate__ (line 113) | def __setstate__(self, state):
    method step (line 119) | def step(self, closure=None):
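The `instantiate_from_config` / `get_obj_from_str` pair listed above is the glue that builds every model in this codebase from its YAML config. A minimal sketch of the standard LDM-style implementation (the actual file may differ in details such as error messages):

```python
import importlib

def get_obj_from_str(string, reload=False):
    # "pkg.module.ClassName" -> the class object itself
    module, cls = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # config is a dict like {"target": "path.to.Class", "params": {...}}
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# usage: build any class via its dotted path, here a stdlib one for illustration
obj = instantiate_from_config({"target": "collections.OrderedDict",
                               "params": {"a": 1}})
```

This indirection is why the training configs in `ToonCrafter/configs/` can swap encoders, samplers, and losses without code changes.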

FILE: ToonCrafter/lvdm/basics.py
  function disabled_train (line 14) | def disabled_train(self, mode=True):
  function zero_module (line 20) | def zero_module(module):
  function scale_module (line 29) | def scale_module(module, scale):
  function conv_nd (line 38) | def conv_nd(dims, *args, **kwargs):
  function linear (line 51) | def linear(*args, **kwargs):
  function avg_pool_nd (line 58) | def avg_pool_nd(dims, *args, **kwargs):
  function nonlinearity (line 71) | def nonlinearity(type='silu'):
  class GroupNormSpecific (line 78) | class GroupNormSpecific(nn.GroupNorm):
    method forward (line 79) | def forward(self, x):
  function normalization (line 83) | def normalization(channels, num_groups=32):
  class HybridConditioner (line 92) | class HybridConditioner(nn.Module):
    method __init__ (line 94) | def __init__(self, c_concat_config, c_crossattn_config):
    method forward (line 99) | def forward(self, c_concat, c_crossattn):
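`zero_module` in `lvdm/basics.py` is the standard trick of zero-initializing a layer so a newly added branch starts as a no-op; a minimal sketch assuming the usual implementation:

```python
import torch
import torch.nn as nn

def zero_module(module):
    # Zero all parameters so the module initially contributes nothing;
    # typically applied to the last conv of a ResBlock or output projection.
    for p in module.parameters():
        p.detach().zero_()
    return module

conv = zero_module(nn.Conv2d(3, 3, 3, padding=1))
out = conv(torch.randn(1, 3, 8, 8))  # all zeros: weight and bias are zeroed
```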

FILE: ToonCrafter/lvdm/common.py
  function gather_data (line 8) | def gather_data(data, return_np=True):
  function autocast (line 16) | def autocast(f):
  function extract_into_tensor (line 25) | def extract_into_tensor(a, t, x_shape):
  function noise_like (line 31) | def noise_like(shape, device, repeat=False):
  function default (line 37) | def default(val, d):
  function exists (line 42) | def exists(val):
  function identity (line 45) | def identity(*args, **kwargs):
  function uniq (line 48) | def uniq(arr):
  function mean_flat (line 51) | def mean_flat(tensor):
  function ismap (line 57) | def ismap(x):
  function isimage (line 62) | def isimage(x):
  function max_neg_value (line 67) | def max_neg_value(t):
  function shape_to_str (line 70) | def shape_to_str(x):
  function init_ (line 74) | def init_(tensor):
  function checkpoint (line 81) | def checkpoint(func, inputs, params, flag):
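`extract_into_tensor` in `lvdm/common.py` gathers per-timestep schedule scalars (alphas, betas, etc.) for a batch and reshapes them to broadcast over latents; a sketch of the conventional LDM helper:

```python
import torch

def extract_into_tensor(a, t, x_shape):
    # a: (num_timesteps,) schedule values; t: (b,) integer timesteps.
    # Returns a[t] reshaped to (b, 1, ..., 1) so it broadcasts against x.
    b = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

alphas = torch.linspace(1.0, 0.0, steps=10)
t = torch.tensor([0, 9])
coeff = extract_into_tensor(alphas, t, (2, 4, 8, 8))  # shape (2, 1, 1, 1)
```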

FILE: ToonCrafter/lvdm/data/base.py
  class Txt2ImgIterableBaseDataset (line 5) | class Txt2ImgIterableBaseDataset(IterableDataset):
    method __init__ (line 9) | def __init__(self, num_records=0, valid_ids=None, size=256):
    method __len__ (line 18) | def __len__(self):
    method __iter__ (line 22) | def __iter__(self):

FILE: ToonCrafter/lvdm/data/webvid.py
  class WebVid (line 13) | class WebVid(Dataset):
    method __init__ (line 25) | def __init__(self,
    method _load_metadata (line 72) | def _load_metadata(self):
    method _get_video_path (line 83) | def _get_video_path(self, sample):
    method __getitem__ (line 88) | def __getitem__(self, index):
    method __len__ (line 170) | def __len__(self):

FILE: ToonCrafter/lvdm/distributions.py
  class AbstractDistribution (line 5) | class AbstractDistribution:
    method sample (line 6) | def sample(self):
    method mode (line 9) | def mode(self):
  class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
    method __init__ (line 14) | def __init__(self, value):
    method sample (line 17) | def sample(self):
    method mode (line 20) | def mode(self):
  class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
    method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
    method sample (line 35) | def sample(self, noise=None):
    method kl (line 42) | def kl(self, other=None):
    method nll (line 56) | def nll(self, sample, dims=[1,2,3]):
    method mode (line 64) | def mode(self):
  function normal_kl (line 68) | def normal_kl(mean1, logvar1, mean2, logvar2):
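`DiagonalGaussianDistribution` is the autoencoder's posterior: the encoder emits mean and log-variance stacked along the channel dim, and sampling uses the reparameterization trick. A sketch of the core of the standard implementation (clamp bounds follow the common LDM convention):

```python
import torch

class DiagonalGaussianDistribution:
    def __init__(self, parameters):
        # parameters: (b, 2c, ...) with mean and log-variance split on dim 1
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self):
        # reparameterization: mean + std * eps, keeps sampling differentiable
        return self.mean + self.std * torch.randn_like(self.mean)

    def mode(self):
        return self.mean

params = torch.cat([torch.zeros(1, 4, 8, 8),
                    torch.full((1, 4, 8, 8), -30.0)], dim=1)
dist = DiagonalGaussianDistribution(params)  # near-deterministic: tiny std
```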

FILE: ToonCrafter/lvdm/ema.py
  class LitEma (line 5) | class LitEma(nn.Module):
    method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
    method forward (line 25) | def forward(self,model):
    method copy_to (line 46) | def copy_to(self, model):
    method store (line 55) | def store(self, parameters):
    method restore (line 64) | def restore(self, parameters):
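`LitEma` maintains an exponential moving average of model weights for more stable sampling checkpoints. The core update, including the warm-up schedule behind the `use_num_upates` flag in its `__init__`, can be sketched as:

```python
import torch

def ema_update(shadow, params, decay=0.9999, num_updates=None):
    # Optional warm-up: effective decay ramps up with the update count,
    # mirroring LitEma's use_num_upates branch.
    if num_updates is not None:
        decay = min(decay, (1 + num_updates) / (10 + num_updates))
    one_minus = 1.0 - decay
    with torch.no_grad():
        for s, p in zip(shadow, params):
            s.sub_(one_minus * (s - p))  # s = decay * s + (1 - decay) * p

shadow = [torch.zeros(3)]
ema_update(shadow, [torch.ones(3)], decay=0.9)  # shadow moves 10% toward params
```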

FILE: ToonCrafter/lvdm/models/autoencoder.py
  class AutoencoderKL (line 15) | class AutoencoderKL(pl.LightningModule):
    method __init__ (line 16) | def __init__(self,
    method init_test (line 58) | def init_test(self,):
    method init_from_ckpt (line 87) | def init_from_ckpt(self, path, ignore_keys=list()):
    method encode (line 104) | def encode(self, x, return_hidden_states=False, **kwargs):
    method decode (line 116) | def decode(self, z, **kwargs):
    method forward (line 122) | def forward(self, input, sample_posterior=True, **additional_decode_kw...
    method _forward (line 127) | def _forward(self, input, sample_posterior=True, **additional_decode_k...
    method get_input (line 137) | def get_input(self, batch, k):
    method training_step (line 147) | def training_step(self, batch, batch_idx, optimizer_idx):
    method validation_step (line 168) | def validation_step(self, batch, batch_idx):
    method configure_optimizers (line 182) | def configure_optimizers(self):
    method get_last_layer (line 193) | def get_last_layer(self):
    method log_images (line 197) | def log_images(self, batch, only_inputs=False, **kwargs):
    method to_rgb (line 213) | def to_rgb(self, x):
  class IdentityFirstStage (line 222) | class IdentityFirstStage(torch.nn.Module):
    method __init__ (line 223) | def __init__(self, *args, vq_interface=False, **kwargs):
    method encode (line 227) | def encode(self, x, *args, **kwargs):
    method decode (line 230) | def decode(self, x, *args, **kwargs):
    method quantize (line 233) | def quantize(self, x, *args, **kwargs):
    method forward (line 238) | def forward(self, x, *args, **kwargs):
  class AutoencoderKL_Dualref (line 245) | class AutoencoderKL_Dualref(AutoencoderKL):
    method __init__ (line 246) | def __init__(self,
    method _forward (line 266) | def _forward(self, input, sample_posterior=True, **additional_decode_k...

FILE: ToonCrafter/lvdm/models/autoencoder_dualref.py
  function nonlinearity (line 24) | def nonlinearity(x):
  function Normalize (line 29) | def Normalize(in_channels, num_groups=32):
  class ResnetBlock (line 35) | class ResnetBlock(nn.Module):
    method __init__ (line 36) | def __init__(
    method forward (line 72) | def forward(self, x, temb):
  class LinAttnBlock (line 95) | class LinAttnBlock(LinearAttention):
    method __init__ (line 98) | def __init__(self, in_channels):
  class AttnBlock (line 102) | class AttnBlock(nn.Module):
    method __init__ (line 103) | def __init__(self, in_channels):
    method attention (line 121) | def attention(self, h_: torch.Tensor) -> torch.Tensor:
    method forward (line 138) | def forward(self, x, **kwargs):
  class MemoryEfficientAttnBlock (line 145) | class MemoryEfficientAttnBlock(nn.Module):
    method __init__ (line 153) | def __init__(self, in_channels):
    method attention (line 172) | def attention(self, h_: torch.Tensor) -> torch.Tensor:
    method forward (line 202) | def forward(self, x, **kwargs):
  class CrossAttentionWrapper (line 209) | class CrossAttentionWrapper(CrossAttention):
    method forward (line 210) | def forward(self, x, context=None, mask=None, **unused_kwargs):
  class MemoryEfficientCrossAttentionWrapper (line 218) | class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    method forward (line 219) | def forward(self, x, context=None, mask=None, **unused_kwargs):
  function make_attn (line 227) | def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
  class CrossAttentionWrapperFusion (line 274) | class CrossAttentionWrapperFusion(CrossAttention):
    method __init__ (line 275) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 282) | def forward(self, x, context=None, mask=None):
    method _forward (line 288) | def _forward(
  class MemoryEfficientCrossAttentionWrapperFusion (line 345) | class MemoryEfficientCrossAttentionWrapperFusion(MemoryEfficientCrossAtt...
    method __init__ (line 347) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 353) | def forward(self, x, context=None, mask=None):
    method _forward (line 359) | def _forward(
  class Combiner (line 431) | class Combiner(nn.Module):
    method __init__ (line 432) | def __init__(self, ch) -> None:
    method forward (line 439) | def forward(self, x, context):
    method _forward (line 445) | def _forward(self, x, context):
  class Decoder (line 459) | class Decoder(nn.Module):
    method __init__ (line 460) | def __init__(
    method _make_attn (line 568) | def _make_attn(self) -> Callable:
    method _make_resblock (line 571) | def _make_resblock(self) -> Callable:
    method _make_conv (line 574) | def _make_conv(self) -> Callable:
    method get_last_layer (line 577) | def get_last_layer(self, **kwargs):
    method forward (line 580) | def forward(self, z, ref_context=None, **kwargs):
  class TimestepBlock (line 636) | class TimestepBlock(nn.Module):
    method forward (line 642) | def forward(self, x: torch.Tensor, emb: torch.Tensor):
  class ResBlock (line 648) | class ResBlock(TimestepBlock):
    method __init__ (line 664) | def __init__(
    method forward (line 754) | def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
    method _forward (line 766) | def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
  class VideoTransformerBlock (line 800) | class VideoTransformerBlock(nn.Module):
    method __init__ (line 806) | def __init__(
    method forward (line 888) | def forward(
    method _forward (line 896) | def _forward(self, x, context=None, timesteps=None):
    method get_last_layer (line 929) | def get_last_layer(self):
  function partialclass (line 939) | def partialclass(cls, *args, **kwargs):
  class VideoResBlock (line 947) | class VideoResBlock(ResnetBlock):
    method __init__ (line 948) | def __init__(
    method get_alpha (line 985) | def get_alpha(self, bs):
    method forward (line 993) | def forward(self, x, temb, skip_video=False, timesteps=None):
  class AE3DConv (line 1015) | class AE3DConv(torch.nn.Conv2d):
    method __init__ (line 1016) | def __init__(self, in_channels, out_channels, video_kernel_size=3, *ar...
    method forward (line 1030) | def forward(self, input, timesteps, skip_video=False):
  class VideoBlock (line 1039) | class VideoBlock(AttnBlock):
    method __init__ (line 1040) | def __init__(
    method forward (line 1071) | def forward(self, x, timesteps, skip_video=False):
    method get_alpha (line 1098) | def get_alpha(
  class MemoryEfficientVideoBlock (line 1109) | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
    method __init__ (line 1110) | def __init__(
    method forward (line 1141) | def forward(self, x, timesteps, skip_time_block=False):
    method get_alpha (line 1168) | def get_alpha(
  function make_time_attn (line 1179) | def make_time_attn(
  class Conv2DWrapper (line 1217) | class Conv2DWrapper(torch.nn.Conv2d):
    method forward (line 1218) | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
  class VideoDecoder (line 1222) | class VideoDecoder(Decoder):
    method __init__ (line 1225) | def __init__(
    method get_last_layer (line 1243) | def get_last_layer(self, skip_time_mix=False, **kwargs):
    method _make_attn (line 1253) | def _make_attn(self) -> Callable:
    method _make_conv (line 1263) | def _make_conv(self) -> Callable:
    method _make_resblock (line 1269) | def _make_resblock(self) -> Callable:

FILE: ToonCrafter/lvdm/models/ddpm3d.py
  class DDPM (line 42) | class DDPM(pl.LightningModule):
    method __init__ (line 44) | def __init__(self,
    method register_schedule (line 125) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
    method ema_scope (line 191) | def ema_scope(self, context=None):
    method init_from_ckpt (line 205) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    method q_mean_variance (line 223) | def q_mean_variance(self, x_start, t):
    method predict_start_from_noise (line 235) | def predict_start_from_noise(self, x_t, t, noise):
    method predict_start_from_z_and_v (line 241) | def predict_start_from_z_and_v(self, x_t, t, v):
    method predict_eps_from_z_and_v (line 249) | def predict_eps_from_z_and_v(self, x_t, t, v):
    method q_posterior (line 255) | def q_posterior(self, x_start, x_t, t):
    method p_mean_variance (line 264) | def p_mean_variance(self, x, t, clip_denoised: bool):
    method p_sample (line 277) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    method p_sample_loop (line 286) | def p_sample_loop(self, shape, return_intermediates=False):
    method sample (line 301) | def sample(self, batch_size=16, return_intermediates=False):
    method q_sample (line 307) | def q_sample(self, x_start, t, noise=None):
    method get_v (line 312) | def get_v(self, x, noise, t):
    method get_loss (line 318) | def get_loss(self, pred, target, mean=True):
    method p_losses (line 333) | def p_losses(self, x_start, t, noise=None):
    method forward (line 364) | def forward(self, x, *args, **kwargs):
    method get_input (line 370) | def get_input(self, batch, k):
    method shared_step (line 380) | def shared_step(self, batch):
    method training_step (line 385) | def training_step(self, batch, batch_idx):
    method validation_step (line 401) | def validation_step(self, batch, batch_idx):
    method on_train_batch_end (line 409) | def on_train_batch_end(self, *args, **kwargs):
    method _get_rows_from_list (line 413) | def _get_rows_from_list(self, samples):
    method log_images (line 421) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
    method configure_optimizers (line 458) | def configure_optimizers(self):
  class LatentDiffusion (line 466) | class LatentDiffusion(DDPM):
    method __init__ (line 468) | def __init__(self,
    method make_cond_schedule (line 550) | def make_cond_schedule(self, ):
    method on_train_batch_start (line 557) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
    method register_schedule (line 574) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
    method instantiate_first_stage (line 582) | def instantiate_first_stage(self, config):
    method instantiate_cond_stage (line 589) | def instantiate_cond_stage(self, config):
    method get_learned_conditioning (line 600) | def get_learned_conditioning(self, c):
    method get_first_stage_encoding (line 613) | def get_first_stage_encoding(self, encoder_posterior, noise=None):
    method encode_first_stage (line 623) | def encode_first_stage(self, x):
    method decode_core (line 648) | def decode_core(self, z, **kwargs):
    method decode_first_stage (line 692) | def decode_first_stage(self, z, **kwargs):
    method differentiable_decode_first_stage (line 696) | def differentiable_decode_first_stage(self, z, **kwargs):
    method get_batch_input (line 700) | def get_batch_input(self, batch, random_uncond, return_first_stage_out...
    method forward (line 733) | def forward(self, x, c, **kwargs):
    method shared_step (line 739) | def shared_step(self, batch, random_uncond, **kwargs):
    method apply_model (line 745) | def apply_model(self, x_noisy, t, cond, **kwargs):
    method p_losses (line 762) | def p_losses(self, x_start, cond, t, noise=None, **kwargs):
    method training_step (line 808) | def training_step(self, batch, batch_idx):
    method _get_denoise_row_from_list (line 822) | def _get_denoise_row_from_list(self, samples, desc=''):
    method log_images (line 847) | def log_images(self, batch, sample=True, ddim_steps=200, ddim_eta=1., ...
    method p_mean_variance (line 902) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_x0=Fals...
    method p_sample (line 928) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, r...
    method p_sample_loop (line 950) | def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=N...
    method sample (line 997) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
    method sample_log (line 1014) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    method configure_schedulers (line 1025) | def configure_schedulers(self, optimizer):
  class LatentVisualDiffusion (line 1051) | class LatentVisualDiffusion(LatentDiffusion):
    method __init__ (line 1052) | def __init__(self, img_cond_stage_config, image_proj_stage_config, fre...
    method _init_img_ctx_projector (line 1058) | def _init_img_ctx_projector(self, config, trainable):
    method _init_embedder (line 1066) | def _init_embedder(self, config, freeze=True):
    method shared_step (line 1074) | def shared_step(self, batch, random_uncond, **kwargs):
    method get_batch_input (line 1080) | def get_batch_input(self, batch, random_uncond, return_first_stage_out...
    method log_images (line 1147) | def log_images(self, batch, sample=True, ddim_steps=50, ddim_eta=1., p...
    method configure_optimizers (line 1218) | def configure_optimizers(self):
  class DiffusionWrapper (line 1253) | class DiffusionWrapper(pl.LightningModule):
    method __init__ (line 1254) | def __init__(self, diff_model_config, conditioning_key):
    method forward (line 1259) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,

FILE: ToonCrafter/lvdm/models/samplers/ddim.py
  class DDIMSampler (line 10) | class DDIMSampler(object):
    method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
    method register_buffer (line 18) | def register_buffer(self, name, attr):
    method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
    method sample (line 60) | def sample(self,
    method ddim_sampling (line 135) | def ddim_sampling(self, cond, shape,
    method p_sample_ddim (line 206) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
    method decode (line 282) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
    method stochastic_encode (line 305) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
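`p_sample_ddim` in both samplers implements the DDIM update. Setting aside classifier-free guidance and the eta noise term, the deterministic (eta=0) step reduces to the arithmetic below, sketched here as a standalone function rather than the class method:

```python
import torch

def ddim_step(x_t, eps, a_t, a_prev):
    # DDIM, eta = 0: recover the x0 estimate from the noise prediction,
    # then re-noise it to the previous alpha_cumprod level.
    pred_x0 = (x_t - (1.0 - a_t) ** 0.5 * eps) / a_t ** 0.5
    dir_xt = (1.0 - a_prev) ** 0.5 * eps  # direction pointing to x_t
    return a_prev ** 0.5 * pred_x0 + dir_xt

x_prev = ddim_step(torch.zeros(1, 4), torch.zeros(1, 4), a_t=0.5, a_prev=0.9)
```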

FILE: ToonCrafter/lvdm/models/samplers/ddim_multiplecond.py
  class DDIMSampler (line 10) | class DDIMSampler(object):
    method __init__ (line 11) | def __init__(self, model, schedule="linear", **kwargs):
    method register_buffer (line 18) | def register_buffer(self, name, attr):
    method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
    method sample (line 60) | def sample(self,
    method ddim_sampling (line 138) | def ddim_sampling(self, cond, shape,
    method p_sample_ddim (line 211) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
    method decode (line 288) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
    method stochastic_encode (line 310) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):

FILE: ToonCrafter/lvdm/models/utils_diffusion.py
  function timestep_embedding (line 8) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  function make_beta_schedule (line 31) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
  function make_ddim_timesteps (line 56) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
  function make_ddim_sampling_parameters (line 79) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
  function betas_for_alpha_bar (line 94) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
  function rescale_zero_terminal_snr (line 112) | def rescale_zero_terminal_snr(betas):
  function rescale_noise_cfg (line 147) | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
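`timestep_embedding` is the sinusoidal embedding from the DDPM/Transformer recipe: a geometric frequency ladder over half the dimensions, with concatenated cos/sin features. A sketch assuming the openaimodel-style convention (the cos/sin ordering and odd-`dim` padding vary between codebases):

```python
import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    # Frequencies decay geometrically from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0, 100]), dim=8)  # shape (2, 8)
```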

FILE: ToonCrafter/lvdm/modules/attention.py
  class RelativePosition (line 20) | class RelativePosition(nn.Module):
    method __init__ (line 23) | def __init__(self, num_units, max_relative_position):
    method forward (line 30) | def forward(self, length_q, length_k):
  class CrossAttention (line 42) | class CrossAttention(nn.Module):
    method __init__ (line 44) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 81) | def forward(self, x, context=None, mask=None):
    method efficient_forward (line 146) | def efficient_forward(self, x, context=None, mask=None):
  class BasicTransformerBlock (line 212) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 214) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
    method forward (line 231) | def forward(self, x, context=None, mask=None, **kwargs):
    method _forward (line 242) | def _forward(self, x, context=None, mask=None):
  class SpatialTransformer (line 249) | class SpatialTransformer(nn.Module):
    method __init__ (line 259) | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., ...
    method forward (line 294) | def forward(self, x, context=None, **kwargs):
  class TemporalTransformer (line 313) | class TemporalTransformer(nn.Module):
    method __init__ (line 320) | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., ...
    method forward (line 365) | def forward(self, x, context=None):
  class GEGLU (line 415) | class GEGLU(nn.Module):
    method __init__ (line 416) | def __init__(self, dim_in, dim_out):
    method forward (line 420) | def forward(self, x):
  class FeedForward (line 425) | class FeedForward(nn.Module):
    method __init__ (line 426) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
    method forward (line 441) | def forward(self, x):
  class LinearAttention (line 445) | class LinearAttention(nn.Module):
    method __init__ (line 446) | def __init__(self, dim, heads=4, dim_head=32):
    method forward (line 453) | def forward(self, x):
  class SpatialSelfAttention (line 464) | class SpatialSelfAttention(nn.Module):
    method __init__ (line 465) | def __init__(self, in_channels):
    method forward (line 491) | def forward(self, x):
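`CrossAttention` is the workhorse of both the spatial and temporal transformers above: queries come from the latent, keys/values from the context (text or image embeddings), and it degrades to self-attention when no context is given. A compact sketch of the math, without the memory-efficient xformers path or relative positions (`MiniCrossAttention` is an illustrative name, not the repo's class):

```python
import torch
import torch.nn as nn

class MiniCrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner = heads * dim_head
        context_dim = context_dim or query_dim
        self.heads, self.scale = heads, dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner, bias=False)
        self.to_k = nn.Linear(context_dim, inner, bias=False)
        self.to_v = nn.Linear(context_dim, inner, bias=False)
        self.to_out = nn.Linear(inner, query_dim)

    def forward(self, x, context=None):
        context = x if context is None else context  # self-attn fallback
        b, n, _ = x.shape
        h = self.heads
        # (b, seq, inner) -> (b, heads, seq, dim_head)
        q = self.to_q(x).view(b, n, h, -1).transpose(1, 2)
        k = self.to_k(context).view(b, context.shape[1], h, -1).transpose(1, 2)
        v = self.to_v(context).view(b, context.shape[1], h, -1).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

m = MiniCrossAttention(32, heads=4, dim_head=8)
y = m(torch.randn(2, 5, 32), torch.randn(2, 7, 32))  # (2, 5, 32)
```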

FILE: ToonCrafter/lvdm/modules/attention_svd.py
  function exists (line 61) | def exists(val):
  function uniq (line 65) | def uniq(arr):
  function default (line 69) | def default(val, d):
  function max_neg_value (line 75) | def max_neg_value(t):
  function init_ (line 79) | def init_(tensor):
  class GEGLU (line 87) | class GEGLU(nn.Module):
    method __init__ (line 88) | def __init__(self, dim_in, dim_out):
    method forward (line 92) | def forward(self, x):
  class FeedForward (line 97) | class FeedForward(nn.Module):
    method __init__ (line 98) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
    method forward (line 112) | def forward(self, x):
  function zero_module (line 116) | def zero_module(module):
  function Normalize (line 125) | def Normalize(in_channels):
  class LinearAttention (line 131) | class LinearAttention(nn.Module):
    method __init__ (line 132) | def __init__(self, dim, heads=4, dim_head=32):
    method forward (line 139) | def forward(self, x):
  class SelfAttention (line 154) | class SelfAttention(nn.Module):
    method __init__ (line 157) | def __init__(
    method forward (line 182) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class SpatialSelfAttention (line 211) | class SpatialSelfAttention(nn.Module):
    method __init__ (line 212) | def __init__(self, in_channels):
    method forward (line 230) | def forward(self, x):
  class CrossAttention (line 256) | class CrossAttention(nn.Module):
    method __init__ (line 257) | def __init__(
    method forward (line 282) | def forward(
  class MemoryEfficientCrossAttention (line 348) | class MemoryEfficientCrossAttention(nn.Module):
    method __init__ (line 350) | def __init__(
    method forward (line 374) | def forward(
  class BasicTransformerBlock (line 457) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 463) | def __init__(
    method forward (line 528) | def forward(
    method _forward (line 552) | def _forward(
  class BasicTransformerSingleLayerBlock (line 576) | class BasicTransformerSingleLayerBlock(nn.Module):
    method __init__ (line 583) | def __init__(
    method forward (line 609) | def forward(self, x, context=None):
    method _forward (line 614) | def _forward(self, x, context=None):
  class SpatialTransformer (line 620) | class SpatialTransformer(nn.Module):
    method __init__ (line 630) | def __init__(
    method forward (line 703) | def forward(self, x, context=None):
  class SimpleTransformer (line 727) | class SimpleTransformer(nn.Module):
    method __init__ (line 728) | def __init__(
    method forward (line 753) | def forward(

FILE: ToonCrafter/lvdm/modules/encoders/condition.py
  class AbstractEncoder (line 12) | class AbstractEncoder(nn.Module):
    method __init__ (line 13) | def __init__(self):
    method encode (line 16) | def encode(self, *args, **kwargs):
  class IdentityEncoder (line 20) | class IdentityEncoder(AbstractEncoder):
    method encode (line 21) | def encode(self, x):
  class ClassEmbedder (line 25) | class ClassEmbedder(nn.Module):
    method __init__ (line 26) | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
    method forward (line 33) | def forward(self, batch, key=None, disable_dropout=False):
    method get_unconditional_conditioning (line 45) | def get_unconditional_conditioning(self, bs, device="cuda"):
  function disabled_train (line 52) | def disabled_train(self, mode=True):
  function get_available_devices (line 58) | def get_available_devices():
  function get_device (line 68) | def get_device(device):
  class FrozenT5Embedder (line 75) | class FrozenT5Embedder(AbstractEncoder):
    method __init__ (line 78) | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_...
    method freeze (line 88) | def freeze(self):
    method forward (line 94) | def forward(self, text):
    method encode (line 103) | def encode(self, text):
  class FrozenCLIPEmbedder (line 107) | class FrozenCLIPEmbedder(AbstractEncoder):
    method __init__ (line 115) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
    method freeze (line 131) | def freeze(self):
    method forward (line 137) | def forward(self, text):
    method encode (line 150) | def encode(self, text):
  class ClipImageEmbedder (line 154) | class ClipImageEmbedder(nn.Module):
    method __init__ (line 155) | def __init__(
    method preprocess (line 173) | def preprocess(self, x):
    method forward (line 183) | def forward(self, x, no_dropout=False):
  class FrozenOpenCLIPEmbedder (line 192) | class FrozenOpenCLIPEmbedder(AbstractEncoder):
    method __init__ (line 202) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
    method freeze (line 223) | def freeze(self):
    method forward (line 228) | def forward(self, text):
    method encode_with_transformer (line 233) | def encode_with_transformer(self, text):
    method text_transformer_forward (line 242) | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
    method encode (line 252) | def encode(self, text):
  class FrozenOpenCLIPImageEmbedder (line 256) | class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    method __init__ (line 261) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
    method preprocess (line 285) | def preprocess(self, x):
    method freeze (line 295) | def freeze(self):
    method forward (line 301) | def forward(self, image, no_dropout=False):
    method encode_with_vision_transformer (line 307) | def encode_with_vision_transformer(self, img):
    method encode (line 312) | def encode(self, text):
  class FrozenOpenCLIPImageEmbedderV2 (line 316) | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
    method __init__ (line 321) | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", devic...
    method preprocess (line 343) | def preprocess(self, x):
    method freeze (line 353) | def freeze(self):
    method forward (line 358) | def forward(self, image, no_dropout=False):
    method encode_with_vision_transformer (line 363) | def encode_with_vision_transformer(self, x):
  class FrozenCLIPT5Encoder (line 396) | class FrozenCLIPT5Encoder(AbstractEncoder):
    method __init__ (line 397) | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_ve...
    method encode (line 405) | def encode(self, text):
    method forward (line 408) | def forward(self, text):

FILE: ToonCrafter/lvdm/modules/encoders/resampler.py
  class ImageProjModel (line 9) | class ImageProjModel(nn.Module):
    method __init__ (line 11) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 18) | def forward(self, image_embeds):
  function FeedForward (line 27) | def FeedForward(dim, mult=4):
  function reshape_tensor (line 37) | def reshape_tensor(x, heads):
  class PerceiverAttention (line 48) | class PerceiverAttention(nn.Module):
    method __init__ (line 49) | def __init__(self, *, dim, dim_head=64, heads=8):
    method forward (line 64) | def forward(self, x, latents):
  class Resampler (line 96) | class Resampler(nn.Module):
    method __init__ (line 97) | def __init__(
    method forward (line 134) | def forward(self, x):

FILE: ToonCrafter/lvdm/modules/networks/ae_modules.py
  function nonlinearity (line 13) | def nonlinearity(x):
  function Normalize (line 18) | def Normalize(in_channels, num_groups=32):
  class LinAttnBlock (line 22) | class LinAttnBlock(LinearAttention):
    method __init__ (line 25) | def __init__(self, in_channels):
  class AttnBlock (line 29) | class AttnBlock(nn.Module):
    method __init__ (line 30) | def __init__(self, in_channels):
    method forward (line 56) | def forward(self, x):
  function make_attn (line 84) | def make_attn(in_channels, attn_type="vanilla"):
  class Downsample (line 95) | class Downsample(nn.Module):
    method __init__ (line 96) | def __init__(self, in_channels, with_conv):
    method forward (line 108) | def forward(self, x):
  class Upsample (line 118) | class Upsample(nn.Module):
    method __init__ (line 119) | def __init__(self, in_channels, with_conv):
    method forward (line 130) | def forward(self, x):
  function get_timestep_embedding (line 137) | def get_timestep_embedding(timesteps, embedding_dim):
  class ResnetBlock (line 158) | class ResnetBlock(nn.Module):
    method __init__ (line 159) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
    method forward (line 197) | def forward(self, x, temb):
  class Model (line 220) | class Model(nn.Module):
    method __init__ (line 221) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
    method forward (line 321) | def forward(self, x, t=None, context=None):
    method get_last_layer (line 369) | def get_last_layer(self):
  class Encoder (line 373) | class Encoder(nn.Module):
    method __init__ (line 374) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
    method forward (line 440) | def forward(self, x, return_hidden_states=False):
  class Decoder (line 486) | class Decoder(nn.Module):
    method __init__ (line 487) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
    method forward (line 560) | def forward(self, z):
  class SimpleDecoder (line 602) | class SimpleDecoder(nn.Module):
    method __init__ (line 603) | def __init__(self, in_channels, out_channels, *args, **kwargs):
    method forward (line 625) | def forward(self, x):
  class UpsampleDecoder (line 638) | class UpsampleDecoder(nn.Module):
    method __init__ (line 639) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
    method forward (line 672) | def forward(self, x):
  class LatentRescaler (line 686) | class LatentRescaler(nn.Module):
    method __init__ (line 687) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
    method forward (line 711) | def forward(self, x):
  class MergedRescaleEncoder (line 723) | class MergedRescaleEncoder(nn.Module):
    method __init__ (line 724) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
    method forward (line 736) | def forward(self, x):
  class MergedRescaleDecoder (line 742) | class MergedRescaleDecoder(nn.Module):
    method __init__ (line 743) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
    method forward (line 753) | def forward(self, x):
  class Upsampler (line 759) | class Upsampler(nn.Module):
    method __init__ (line 760) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
    method forward (line 772) | def forward(self, x):
  class Resize (line 778) | class Resize(nn.Module):
    method __init__ (line 779) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
    method forward (line 794) | def forward(self, x, scale_factor=1.0):
  class FirstStagePostProcessor (line 802) | class FirstStagePostProcessor(nn.Module):
    method __init__ (line 804) | def __init__(self, ch_mult: list, in_channels,
    method instantiate_pretrained (line 838) | def instantiate_pretrained(self, config):
    method encode_with_pretrained (line 846) | def encode_with_pretrained(self, x):
    method forward (line 852) | def forward(self, x):

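The outline above lists `get_timestep_embedding` (line 137). In comparable diffusion codebases this is the standard DDPM-style sinusoidal embedding: half the channels are sines, half cosines, at log-spaced frequencies. A minimal dependency-free sketch (the real function operates on torch tensors and zero-pads odd dims):

```python
import math

def timestep_embedding(t: int, dim: int) -> list[float]:
    """Sinusoidal embedding of a single timestep: [sin(t*f_i)..., cos(t*f_i)...]."""
    half = dim // 2
    # log-spaced frequencies from 1 down to ~1/10000
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = timestep_embedding(0, 8)  # at t=0: sines are 0.0, cosines are 1.0
```
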
FILE: ToonCrafter/lvdm/modules/networks/openaimodel3d.py
  class TimestepBlock (line 19) | class TimestepBlock(nn.Module):
    method forward (line 24) | def forward(self, x, emb):
  class TimestepEmbedSequential (line 30) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    method forward (line 36) | def forward(self, x, emb, context=None, batch_size=None):
  class Downsample (line 51) | class Downsample(nn.Module):
    method __init__ (line 60) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
    method forward (line 75) | def forward(self, x):
  class Upsample (line 80) | class Upsample(nn.Module):
    method __init__ (line 89) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
    method forward (line 98) | def forward(self, x):
  class ResBlock (line 109) | class ResBlock(TimestepBlock):
    method __init__ (line 126) | def __init__(
    method forward (line 197) | def forward(self, x, emb, batch_size=None):
    method _forward (line 210) | def _forward(self, x, emb, batch_size=None):
  class TemporalConvBlock (line 239) | class TemporalConvBlock(nn.Module):
    method __init__ (line 243) | def __init__(self, in_channels, out_channels=None, dropout=0.0, spatia...
    method forward (line 272) | def forward(self, x):
  class UNetModel (line 281) | class UNetModel(nn.Module):
    method __init__ (line 311) | def __init__(self,
    method forward (line 548) | def forward(self, x, timesteps, context=None, features_adapter=None, f...

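`TimestepEmbedSequential` (line 30) follows the OpenAI UNet convention: a sequential container that forwards the timestep embedding `emb` only to layers that subclass `TimestepBlock` (and, in the real model, `context` to attention layers). A minimal sketch of the dispatch pattern, with toy stand-in layers instead of the actual ResBlock/attention modules:

```python
class TimestepBlock:
    """Marker base class: layers whose forward also consumes the timestep embedding."""
    def forward(self, x, emb):
        raise NotImplementedError

class AddEmb(TimestepBlock):  # toy stand-in for a ResBlock that injects emb
    def forward(self, x, emb):
        return x + emb

class Double:  # toy stand-in for a plain spatial layer (ignores emb)
    def forward(self, x):
        return 2 * x

class TimestepEmbedSequential(list):
    """Run layers in order, passing `emb` only to TimestepBlock instances."""
    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer.forward(x, emb)
            else:
                x = layer.forward(x)
        return x

seq = TimestepEmbedSequential([AddEmb(), Double()])
result = seq.forward(3, 1)  # (3 + 1) * 2
```
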
FILE: ToonCrafter/lvdm/modules/x_transformer.py
  class AbsolutePositionalEmbedding (line 24) | class AbsolutePositionalEmbedding(nn.Module):
    method __init__ (line 25) | def __init__(self, dim, max_seq_len):
    method init_ (line 30) | def init_(self):
    method forward (line 33) | def forward(self, x):
  class FixedPositionalEmbedding (line 38) | class FixedPositionalEmbedding(nn.Module):
    method __init__ (line 39) | def __init__(self, dim):
    method forward (line 44) | def forward(self, x, seq_dim=1, offset=0):
  function exists (line 53) | def exists(val):
  function default (line 57) | def default(val, d):
  function always (line 63) | def always(val):
  function not_equals (line 69) | def not_equals(val):
  function equals (line 75) | def equals(val):
  function max_neg_value (line 81) | def max_neg_value(tensor):
  function pick_and_pop (line 87) | def pick_and_pop(keys, d):
  function group_dict_by_key (line 92) | def group_dict_by_key(cond, d):
  function string_begins_with (line 101) | def string_begins_with(prefix, str):
  function group_by_key_prefix (line 105) | def group_by_key_prefix(prefix, d):
  function groupby_prefix_and_trim (line 109) | def groupby_prefix_and_trim(prefix, d):
  class Scale (line 116) | class Scale(nn.Module):
    method __init__ (line 117) | def __init__(self, value, fn):
    method forward (line 122) | def forward(self, x, **kwargs):
  class Rezero (line 127) | class Rezero(nn.Module):
    method __init__ (line 128) | def __init__(self, fn):
    method forward (line 133) | def forward(self, x, **kwargs):
  class ScaleNorm (line 138) | class ScaleNorm(nn.Module):
    method __init__ (line 139) | def __init__(self, dim, eps=1e-5):
    method forward (line 145) | def forward(self, x):
  class RMSNorm (line 150) | class RMSNorm(nn.Module):
    method __init__ (line 151) | def __init__(self, dim, eps=1e-8):
    method forward (line 157) | def forward(self, x):
  class Residual (line 162) | class Residual(nn.Module):
    method forward (line 163) | def forward(self, x, residual):
  class GRUGating (line 167) | class GRUGating(nn.Module):
    method __init__ (line 168) | def __init__(self, dim):
    method forward (line 172) | def forward(self, x, residual):
  class GEGLU (line 183) | class GEGLU(nn.Module):
    method __init__ (line 184) | def __init__(self, dim_in, dim_out):
    method forward (line 188) | def forward(self, x):
  class FeedForward (line 193) | class FeedForward(nn.Module):
    method __init__ (line 194) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
    method forward (line 209) | def forward(self, x):
  class Attention (line 214) | class Attention(nn.Module):
    method __init__ (line 215) | def __init__(
    method forward (line 267) | def forward(
  class AttentionLayers (line 369) | class AttentionLayers(nn.Module):
    method __init__ (line 370) | def __init__(
    method forward (line 480) | def forward(
  class Encoder (line 540) | class Encoder(AttentionLayers):
    method __init__ (line 541) | def __init__(self, **kwargs):
  class TransformerWrapper (line 547) | class TransformerWrapper(nn.Module):
    method __init__ (line 548) | def __init__(
    method init_ (line 594) | def init_(self):
    method forward (line 597) | def forward(

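Among the listed blocks, `GEGLU` (line 183) is the gated feed-forward variant from the x-transformers lineage: the input is projected to twice the output width, split in half, and the first half is multiplied by GELU of the second. A dependency-free sketch of the split-and-gate step (the real module wraps this around an `nn.Linear`):

```python
import math

def gelu(v: float) -> float:
    # tanh approximation of GELU
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))

def geglu(projected: list[float]) -> list[float]:
    """Split the doubled projection in half: value * GELU(gate)."""
    half = len(projected) // 2
    value, gate = projected[:half], projected[half:]
    return [v * gelu(g) for v, g in zip(value, gate)]

out = geglu([1.0, 2.0, 0.0, 0.0])  # gate of 0.0 zeroes the value half
```
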
FILE: ToonCrafter/main/callbacks.py
  class ImageLogger (line 15) | class ImageLogger(Callback):
    method __init__ (line 16) | def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=...
    method log_to_tensorboard (line 31) | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, s...
    method log_batch_imgs (line 58) | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
    method on_train_batch_end (line 90) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_validation_batch_end (line 94) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
  class CUDACallback (line 104) | class CUDACallback(Callback):
    method on_train_epoch_start (line 106) | def on_train_epoch_start(self, trainer, pl_module):
    method on_train_epoch_end (line 117) | def on_train_epoch_end(self, trainer, pl_module):

FILE: ToonCrafter/main/trainer.py
  function get_parser (line 14) | def get_parser(**parser_kwargs):
  function get_nondefault_trainer_args (line 33) | def get_nondefault_trainer_args(args):
  function melk (line 129) | def melk(*args, **kwargs):
  function divein (line 136) | def divein(*args, **kwargs):

FILE: ToonCrafter/main/utils_data.py
  function worker_init_fn (line 15) | def worker_init_fn(_):
  class WrappedDataset (line 31) | class WrappedDataset(Dataset):
    method __init__ (line 34) | def __init__(self, dataset):
    method __len__ (line 37) | def __len__(self):
    method __getitem__ (line 40) | def __getitem__(self, idx):
  class DataModuleFromConfig (line 44) | class DataModuleFromConfig(pl.LightningDataModule):
    method __init__ (line 45) | def __init__(self, batch_size, train=None, validation=None, test=None,...
    method prepare_data (line 72) | def prepare_data(self):
    method setup (line 75) | def setup(self, stage=None):
    method _train_dataloader (line 81) | def _train_dataloader(self):
    method _val_dataloader (line 93) | def _val_dataloader(self, shuffle=False):
    method _test_dataloader (line 106) | def _test_dataloader(self, shuffle=False):
    method _predict_dataloader (line 128) | def _predict_dataloader(self, shuffle=False):

FILE: ToonCrafter/main/utils_train.py
  function init_workspace (line 9) | def init_workspace(name, logdir, model_config, lightning_config, rank=0):
  function check_config_attribute (line 28) | def check_config_attribute(config, name):
  function get_trainer_callbacks (line 35) | def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, log...
  function get_trainer_logger (line 99) | def get_trainer_logger(lightning_config, logdir, on_debug):
  function get_trainer_strategy (line 125) | def get_trainer_strategy(lightning_config):
  function load_checkpoints (line 138) | def load_checkpoints(model, model_cfg):
  function set_logger (line 162) | def set_logger(logfile, name='mainlogger'):

FILE: ToonCrafter/scripts/evaluation/ddp_wrapper.py
  function setup_dist (line 8) | def setup_dist(local_rank):
  function get_dist_info (line 15) | def get_dist_info():

FILE: ToonCrafter/scripts/evaluation/funcs.py
  function batch_ddim_sampling (line 17) | def batch_ddim_sampling(model: LatentDiffusion, cond, noise_shape, n_sam...
  function get_filelist (line 99) | def get_filelist(data_dir, ext='*'):
  function get_dirlist (line 105) | def get_dirlist(path):
  function load_model_checkpoint (line 117) | def load_model_checkpoint(model, ckpt):
  function load_prompts (line 155) | def load_prompts(prompt_file):
  function load_video_batch (line 166) | def load_video_batch(filepath_list, frame_stride, video_size=(256, 256),...
  function load_image_batch (line 208) | def load_image_batch(filepath_list, image_size=(256, 256)):
  function save_videos (line 232) | def save_videos(batch_tensors, savedir, filenames, fps=10):
  function get_latent_z (line 247) | def get_latent_z(model, videos):

FILE: ToonCrafter/scripts/evaluation/inference.py
  function get_filelist (line 19) | def get_filelist(data_dir, postfixes):
  function load_model_checkpoint (line 27) | def load_model_checkpoint(model, ckpt):
  function load_prompts (line 54) | def load_prompts(prompt_file):
  function load_data_prompts (line 64) | def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, i...
  function save_results (line 109) | def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
  function save_results_seperate (line 135) | def save_results_seperate(prompt, samples, filename, fakedir, fps=10, lo...
  function get_latent_z (line 157) | def get_latent_z(model, videos):
  function get_latent_z_with_hidden_states (line 164) | def get_latent_z_with_hidden_states(model, videos):
  function image_guided_synthesis (line 180) | def image_guided_synthesis(model, prompts, videos, noise_shape, n_sample...
  function run_inference (line 277) | def run_inference(args, gpu_num, gpu_no):
  function get_parser (line 344) | def get_parser():

FILE: ToonCrafter/scripts/gradio/i2v_test.py
  class Image2Video (line 13) | class Image2Video():
    method __init__ (line 14) | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') ...
    method get_image (line 37) | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, f...
    method download_model (line 94) | def download_model(self):

FILE: ToonCrafter/scripts/gradio/i2v_test_application.py
  class Image2Video (line 13) | class Image2Video():
    method __init__ (line 14) | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') ...
    method get_image (line 38) | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, f...
    method download_model (line 117) | def download_model(self):
    method get_latent_z_with_hidden_states (line 127) | def get_latent_z_with_hidden_states(self, model, videos):

FILE: ToonCrafter/utils/save_video.py
  function frames_to_mp4 (line 14) | def frames_to_mp4(frame_dir,output_path,fps):
  function tensor_to_mp4 (line 27) | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
  function tensor2videogrids (line 44) | def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=Tr...
  function log_local (line 62) | def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
  function prepare_to_log (line 120) | def prepare_to_log(batch_logs, max_images=100000, clamp=True):
  function fill_with_black_squares (line 140) | def fill_with_black_squares(video, desired_len: int) -> Tensor:
  function load_num_videos (line 150) | def load_num_videos(data_path, num_videos):
  function npz_to_video_grid (line 163) | def npz_to_video_grid(data_path, out_path, num_frames, fps, num_videos=N...

FILE: ToonCrafter/utils/utils.py
  function count_params (line 9) | def count_params(model, verbose=False):
  function check_istarget (line 16) | def check_istarget(name, para_list):
  function instantiate_from_config (line 28) | def instantiate_from_config(config):
  function get_obj_from_str (line 38) | def get_obj_from_str(string, reload=False):
  function load_npz_from_dir (line 46) | def load_npz_from_dir(data_dir):
  function load_npz_from_paths (line 52) | def load_npz_from_paths(data_paths):
  function resize_numpy_image (line 58) | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edg...
  function setup_dist (line 71) | def setup_dist(args):

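`instantiate_from_config` / `get_obj_from_str` (lines 28/38) follow the convention shared across latent-diffusion codebases: a YAML node with a dotted `target` path and optional `params` dict is turned into a live object via `importlib`. This is how configs like `target: lvdm.models.ddpm3d.LatentVisualDiffusion` become model instances. A minimal sketch, with a stdlib class substituted for the model classes to keep it self-contained:

```python
import importlib

def get_obj_from_str(string: str):
    """Resolve a dotted 'package.module.Name' path to the object it names."""
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    """Build the object named by config['target'], passing config['params'] as kwargs."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))

# stdlib target keeps the demo runnable without the repo installed
obj = instantiate_from_config({"target": "collections.OrderedDict"})
```
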
FILE: __init__.py
  function instantiate_from_config (line 36) | def instantiate_from_config(config):
  function get_obj_from_str (line 46) | def get_obj_from_str(string, reload=False):
  function get_state_dict (line 54) | def get_state_dict(d):
  function load_state_dict (line 58) | def load_state_dict(ckpt_path, location='cpu'):
  function get_models (line 71) | def get_models(root: Path = ROOT.joinpath("checkpoints"), ignoreed: tupl...
  class ToonCrafterNode (line 83) | class ToonCrafterNode:
    method INPUT_TYPES (line 86) | def INPUT_TYPES(s):
    method init (line 111) | def init(self, ckpt_name="", result_dir=ROOT.joinpath("tmp/"), gpu_num...
    method optional_autocast (line 144) | def optional_autocast(device):
    method get_image (line 152) | def get_image(self, image: torch.Tensor, ckpt_name, vram_opt_strategy,...
    method save_videos (line 266) | def save_videos(self, batch_tensors, savedir, filenames, fps=10):
    method download_model (line 281) | def download_model(self):
    method get_latent_z_with_hidden_states (line 291) | def get_latent_z_with_hidden_states(self, model, videos):
  class ToonCrafterWithSketch (line 308) | class ToonCrafterWithSketch(ToonCrafterNode):
    method INPUT_TYPES (line 311) | def INPUT_TYPES(s):
    method init (line 337) | def init(self, ckpt_name="", result_dir=ROOT.joinpath("tmp/"), gpu_num...
    method get_image (line 380) | def get_image(self, image: torch.Tensor, ckpt_name, vram_opt_strategy,...

FILE: pre_run.py
  function run (line 9) | def run():
Condensed preview — 98 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".github/workflows/publish.yml",
    "chars": 581,
    "preview": "name: Publish to Comfy registry\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - main\n      - master\n    paths:\n  "
  },
  {
    "path": ".gitignore",
    "chars": 135,
    "preview": ".DS_Store\n*pyc\n.vscode\n__pycache__\n*.egg-info\n\ncheckpoints\nToonCrafter/checkpoints\nresults\nbackup\nLOG\n/models\nToonCrafte"
  },
  {
    "path": "LICENSE",
    "chars": 11331,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "ToonCrafter/.gitignore",
    "chars": 77,
    "preview": ".DS_Store\n*pyc\n.vscode\n__pycache__\n*.egg-info\n\ncheckpoints\nresults\nbackup\nLOG"
  },
  {
    "path": "ToonCrafter/LICENSE",
    "chars": 11331,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "ToonCrafter/README.md",
    "chars": 5619,
    "preview": "## ___***ToonCrafter: Generative Cartoon Interpolation***___\n<!-- ![](./assets/logo_long.png#gh-light-mode-only){: width"
  },
  {
    "path": "ToonCrafter/__init__.py",
    "chars": 86,
    "preview": "import sys\nfrom pathlib import Path\nsys.path.append(Path(__file__).parent.as_posix())\n"
  },
  {
    "path": "ToonCrafter/cldm/cldm.py",
    "chars": 21052,
    "preview": "import einops\nimport torch\nimport torch as th\nimport torch.nn as nn\n\nfrom ToonCrafter.ldm.modules.diffusionmodules.util "
  },
  {
    "path": "ToonCrafter/cldm/ddim_hacked.py",
    "chars": 16446,
    "preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom ToonCrafter.ldm.modules.diffusionmodul"
  },
  {
    "path": "ToonCrafter/cldm/hack.py",
    "chars": 3579,
    "preview": "import torch\nimport einops\n\nimport ldm.modules.encoders.modules\nimport ldm.modules.attention\n\nfrom transformers import l"
  },
  {
    "path": "ToonCrafter/cldm/logger.py",
    "chars": 3182,
    "preview": "import os\n\nimport numpy as np\nimport torch\nimport torchvision\nfrom PIL import Image\nfrom pytorch_lightning.callbacks imp"
  },
  {
    "path": "ToonCrafter/cldm/model.py",
    "chars": 842,
    "preview": "import os\nimport torch\n\nfrom omegaconf import OmegaConf\nfrom comfy.ldm.util import instantiate_from_config\n\n\ndef get_sta"
  },
  {
    "path": "ToonCrafter/configs/cldm_v21.yaml",
    "chars": 476,
    "preview": "control_stage_config:\n  target: ToonCrafter.cldm.cldm.ControlNet\n  params:\n    use_checkpoint: True\n    image_size: 32 #"
  },
  {
    "path": "ToonCrafter/configs/inference_512_v1.0.yaml",
    "chars": 2519,
    "preview": "model:\n  target: lvdm.models.ddpm3d.LatentVisualDiffusion\n  params:\n    rescale_betas_zero_snr: True\n    parameterizatio"
  },
  {
    "path": "ToonCrafter/configs/training_1024_v1.0/config.yaml",
    "chars": 4260,
    "preview": "model:\n  pretrained_checkpoint: checkpoints/dynamicrafter_1024_v1/model.ckpt\n  base_learning_rate: 1.0e-05\n  scale_lr: F"
  },
  {
    "path": "ToonCrafter/configs/training_1024_v1.0/run.sh",
    "chars": 1020,
    "preview": "# NCCL configuration\n# export NCCL_DEBUG=INFO\n# export NCCL_IB_DISABLE=0\n# export NCCL_IB_GID_INDEX=3\n# export NCCL_NET_"
  },
  {
    "path": "ToonCrafter/configs/training_512_v1.0/config.yaml",
    "chars": 4257,
    "preview": "model:\n  pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt\n  base_learning_rate: 1.0e-05\n  scale_lr: Fa"
  },
  {
    "path": "ToonCrafter/configs/training_512_v1.0/run.sh",
    "chars": 1019,
    "preview": "# NCCL configuration\n# export NCCL_DEBUG=INFO\n# export NCCL_IB_DISABLE=0\n# export NCCL_IB_GID_INDEX=3\n# export NCCL_NET_"
  },
  {
    "path": "ToonCrafter/gradio_app.py",
    "chars": 4212,
    "preview": "import os\nimport argparse\nimport sys\nimport gradio as gr\nfrom scripts.gradio.i2v_test_application import Image2Video\nsys"
  },
  {
    "path": "ToonCrafter/ldm/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/data/util.py",
    "chars": 641,
    "preview": "import torch\n\nfrom ToonCrafter.ldm.modules.midas.api import load_midas_transform\n\n\nclass AddMiDaS(object):\n    def __ini"
  },
  {
    "path": "ToonCrafter/ldm/models/autoencoder.py",
    "chars": 8608,
    "preview": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\n\nfrom "
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/ddim.py",
    "chars": 17316,
    "preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom ToonCrafter.ldm.modules.diffusionmodul"
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/ddpm.py",
    "chars": 84719,
    "preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/dpm_solver/__init__.py",
    "chars": 37,
    "preview": "from .sampler import DPMSolverSampler"
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/dpm_solver/dpm_solver.py",
    "chars": 65968,
    "preview": "import torch\nimport torch.nn.functional as F\nimport math\nfrom tqdm import tqdm\n\n\nclass NoiseScheduleVP:\n    def __init__"
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/dpm_solver/sampler.py",
    "chars": 2990,
    "preview": "\"\"\"SAMPLING ONLY.\"\"\"\nimport torch\n\nfrom .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver\n\n\nMODEL_TYPES = {\n"
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/plms.py",
    "chars": 12951,
    "preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ToonCraf"
  },
  {
    "path": "ToonCrafter/ldm/models/diffusion/sampling_util.py",
    "chars": 753,
    "preview": "import torch\nimport numpy as np\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor unti"
  },
  {
    "path": "ToonCrafter/ldm/modules/attention.py",
    "chars": 11818,
    "preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
  },
  {
    "path": "ToonCrafter/ldm/modules/diffusionmodules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/modules/diffusionmodules/model.py",
    "chars": 34396,
    "preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom ein"
  },
  {
    "path": "ToonCrafter/ldm/modules/diffusionmodules/openaimodel.py",
    "chars": 30413,
    "preview": "from abc import abstractmethod\nimport math\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn."
  },
  {
    "path": "ToonCrafter/ldm/modules/diffusionmodules/upscaling.py",
    "chars": 3448,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom functools import partial\n\nfrom ToonCrafter.ldm.modules.diffus"
  },
  {
    "path": "ToonCrafter/ldm/modules/diffusionmodules/util.py",
    "chars": 9880,
    "preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
  },
  {
    "path": "ToonCrafter/ldm/modules/distributions/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/modules/distributions/distributions.py",
    "chars": 2970,
    "preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n"
  },
  {
    "path": "ToonCrafter/ldm/modules/ema.py",
    "chars": 3110,
    "preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates="
  },
  {
    "path": "ToonCrafter/ldm/modules/encoders/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/modules/encoders/modules.py",
    "chars": 7623,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.utils.checkpoint import checkpoint\n\nfrom transformers import T5Tokenizer, "
  },
  {
    "path": "ToonCrafter/ldm/modules/image_degradation/__init__.py",
    "chars": 232,
    "preview": "from ToonCrafter.ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ToonC"
  },
  {
    "path": "ToonCrafter/ldm/modules/image_degradation/bsrgan.py",
    "chars": 25198,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# ------------------------"
  },
  {
    "path": "ToonCrafter/ldm/modules/image_degradation/bsrgan_light.py",
    "chars": 22341,
    "preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom sci"
  },
  {
    "path": "ToonCrafter/ldm/modules/image_degradation/utils_image.py",
    "chars": 29022,
    "preview": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nf"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/api.py",
    "chars": 5386,
    "preview": "# based on https://github.com/isl-org/MiDaS\n\nimport cv2\nimport torch\nimport torch.nn as nn\nfrom torchvision.transforms i"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/base_model.py",
    "chars": 367,
    "preview": "import torch\n\n\nclass BaseModel(torch.nn.Module):\n    def load(self, path):\n        \"\"\"Load model from file.\n\n        Arg"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/blocks.py",
    "chars": 9242,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .vit import (\n    _make_pretrained_vitb_rn50_384,\n    _make_pretrained_vitl16_3"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/dpt_depth.py",
    "chars": 3154,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .base_model import BaseModel\nfrom .blocks impor"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/midas_net.py",
    "chars": 2709,
    "preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/midas_net_custom.py",
    "chars": 5207,
    "preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/transforms.py",
    "chars": 7869,
    "preview": "import numpy as np\nimport cv2\nimport math\n\n\ndef apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/midas/vit.py",
    "chars": 14625,
    "preview": "import torch\nimport torch.nn as nn\nimport timm\nimport types\nimport math\nimport torch.nn.functional as F\n\n\nclass Slice(nn"
  },
  {
    "path": "ToonCrafter/ldm/modules/midas/utils.py",
    "chars": 4582,
    "preview": "\"\"\"Utils for monoDepth.\"\"\"\nimport sys\nimport re\nimport numpy as np\nimport cv2\nimport torch\n\n\ndef read_pfm(path):\n    \"\"\""
  },
  {
    "path": "ToonCrafter/ldm/util.py",
    "chars": 7227,
    "preview": "import importlib\n\nimport torch\nfrom torch import optim\nimport numpy as np\n\nfrom inspect import isfunction\nfrom PIL impor"
  },
  {
    "path": "ToonCrafter/lvdm/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/lvdm/basics.py",
    "chars": 2864,
    "preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
  },
  {
    "path": "ToonCrafter/lvdm/common.py",
    "chars": 2819,
    "preview": "import math\nfrom inspect import isfunction\nimport torch\nfrom torch import nn\nimport torch.distributed as dist\n\n\ndef gath"
  },
  {
    "path": "ToonCrafter/lvdm/data/base.py",
    "chars": 655,
    "preview": "from abc import abstractmethod\nfrom torch.utils.data import IterableDataset\n\n\nclass Txt2ImgIterableBaseDataset(IterableD"
  },
  {
    "path": "ToonCrafter/lvdm/data/webvid.py",
    "chars": 7926,
    "preview": "import os\nimport random\nfrom tqdm import tqdm\nimport pandas as pd\nfrom decord import VideoReader, cpu\n\nimport torch\nfrom"
  },
  {
    "path": "ToonCrafter/lvdm/distributions.py",
    "chars": 3042,
    "preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n"
  },
  {
    "path": "ToonCrafter/lvdm/ema.py",
    "chars": 2981,
    "preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates="
  },
  {
    "path": "ToonCrafter/lvdm/models/autoencoder.py",
    "chars": 11382,
    "preview": "import os\nfrom contextlib import contextmanager\nimport torch\nimport numpy as np\nfrom einops import rearrange\nimport torc"
  },
  {
    "path": "ToonCrafter/lvdm/models/autoencoder_dualref.py",
    "chars": 42914,
    "preview": "#### https://github.com/Stability-AI/generative-models\nfrom einops import rearrange, repeat\nimport logging\nfrom typing i"
  },
  {
    "path": "ToonCrafter/lvdm/models/ddpm3d.py",
    "chars": 58857,
    "preview": "\"\"\"\nwild mixture of\nhttps://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_"
  },
  {
    "path": "ToonCrafter/lvdm/models/samplers/ddim.py",
    "chars": 16095,
    "preview": "import numpy as np\nfrom tqdm import tqdm\nimport torch\nfrom lvdm.models.utils_diffusion import make_ddim_sampling_paramet"
  },
  {
    "path": "ToonCrafter/lvdm/models/samplers/ddim_multiplecond.py",
    "chars": 16351,
    "preview": "import numpy as np\nfrom tqdm import tqdm\nimport torch\nfrom lvdm.models.utils_diffusion import make_ddim_sampling_paramet"
  },
  {
    "path": "ToonCrafter/lvdm/models/utils_diffusion.py",
    "chars": 6833,
    "preview": "import math\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom einops import repeat\n\n\ndef timestep_emb"
  },
  {
    "path": "ToonCrafter/lvdm/modules/attention.py",
    "chars": 21384,
    "preview": "import torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom func"
  },
  {
    "path": "ToonCrafter/lvdm/modules/attention_svd.py",
    "chars": 25167,
    "preview": "import logging\nimport math\nfrom inspect import isfunction\nfrom typing import Any, Optional\n\nimport torch\nimport torch.nn"
  },
  {
    "path": "ToonCrafter/lvdm/modules/encoders/condition.py",
    "chars": 15332,
    "preview": "import torch\nimport torch.nn as nn\nimport kornia\nimport open_clip\nimport os\nfrom torch.utils.checkpoint import checkpoin"
  },
  {
    "path": "ToonCrafter/lvdm/modules/encoders/resampler.py",
    "chars": 4961,
    "preview": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://gith"
  },
  {
    "path": "ToonCrafter/lvdm/modules/networks/ae_modules.py",
    "chars": 34864,
    "preview": "# pytorch_diffusion + derived encoder decoder\nimport math\n\nimport torch\nimport numpy as np\nimport torch.nn as nn\nfrom ei"
  },
  {
    "path": "ToonCrafter/lvdm/modules/networks/openaimodel3d.py",
    "chars": 26086,
    "preview": "from functools import partial\nfrom abc import abstractmethod\nimport torch\nimport torch.nn as nn\nfrom einops import rearr"
  },
  {
    "path": "ToonCrafter/lvdm/modules/x_transformer.py",
    "chars": 20157,
    "preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nfrom functools import partial\nf"
  },
  {
    "path": "ToonCrafter/main/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/main/callbacks.py",
    "chars": 6312,
    "preview": "import os\nimport time\nimport logging\nmainlogger = logging.getLogger('mainlogger')\n\nimport torch\nimport torchvision\nimpor"
  },
  {
    "path": "ToonCrafter/main/trainer.py",
    "chars": 7494,
    "preview": "import argparse, os, sys, datetime\nfrom omegaconf import OmegaConf\nfrom transformers import logging as transf_logging\nim"
  },
  {
    "path": "ToonCrafter/main/utils_data.py",
    "chars": 5603,
    "preview": "from functools import partial\nimport numpy as np\n\nimport torch\nimport pytorch_lightning as pl\nfrom torch.utils.data impo"
  },
  {
    "path": "ToonCrafter/main/utils_train.py",
    "chars": 6963,
    "preview": "import os, re\nfrom omegaconf import OmegaConf\nimport logging\nmainlogger = logging.getLogger('mainlogger')\n\nimport torch\n"
  },
  {
    "path": "ToonCrafter/prompts/512_interp/prompts.txt",
    "chars": 41,
    "preview": "walking man\nan anime scene\nan anime scene"
  },
  {
    "path": "ToonCrafter/requirements.txt",
    "chars": 307,
    "preview": "decord==0.6.0\neinops==0.3.0\nimageio==2.9.0\nnumpy==1.24.2\nomegaconf==2.1.1\nopencv_python\npandas==2.0.0\nPillow==9.5.0\npyto"
  },
  {
    "path": "ToonCrafter/scripts/evaluation/ddp_wrapper.py",
    "chars": 1535,
    "preview": "import datetime\nimport argparse, importlib\nfrom pytorch_lightning import seed_everything\n\nimport torch\nimport torch.dist"
  },
  {
    "path": "ToonCrafter/scripts/evaluation/funcs.py",
    "chars": 10293,
    "preview": "import os\nimport sys\nimport glob\nimport numpy as np\nfrom collections import OrderedDict\nfrom decord import VideoReader, "
  },
  {
    "path": "ToonCrafter/scripts/evaluation/inference.py",
    "chars": 18558,
    "preview": "import argparse, os, sys, glob\nimport datetime, time\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom einops i"
  },
  {
    "path": "ToonCrafter/scripts/gradio/i2v_test.py",
    "chars": 5037,
    "preview": "import os\nimport time\nfrom omegaconf import OmegaConf\nimport torch\nfrom scripts.evaluation.funcs import load_model_check"
  },
  {
    "path": "ToonCrafter/scripts/gradio/i2v_test_application.py",
    "chars": 6671,
    "preview": "import os\nimport time\nfrom omegaconf import OmegaConf\nimport torch\nfrom scripts.evaluation.funcs import load_model_check"
  },
  {
    "path": "ToonCrafter/scripts/run.sh",
    "chars": 727,
    "preview": "\nckpt=checkpoints/tooncrafter_512_interp_v1/model.ckpt\nconfig=configs/inference_512_v1.0.yaml\n\nprompt_dir=prompts/512_in"
  },
  {
    "path": "ToonCrafter/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ToonCrafter/utils/save_video.py",
    "chars": 8485,
    "preview": "import os\nimport numpy as np\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom einops import rearrange\n\nimport torch\nimpo"
  },
  {
    "path": "ToonCrafter/utils/utils.py",
    "chars": 2179,
    "preview": "import os\nimport importlib\nimport numpy as np\nimport cv2\nimport torch\nimport torch.distributed as dist\n\n\ndef count_param"
  },
  {
    "path": "__init__.py",
    "chars": 23263,
    "preview": "import os\nimport sys\nimport torch\nimport time\nimport logging as logger\nimport importlib\n\nfrom functools import cache\nfro"
  },
  {
    "path": "pre_run.py",
    "chars": 1435,
    "preview": "import sys\nimport argparse\nimport os\nfrom pathlib import Path\nsys.path.append(Path(__file__).parent.as_posix())\nfrom scr"
  },
  {
    "path": "pyproject.toml",
    "chars": 774,
    "preview": "[project]\nname = \"comfyui-tooncrafter\"\ndescription = \"This project is used to enable [a/ToonCrafter](https://github.com/"
  },
  {
    "path": "readme.md",
    "chars": 2150,
    "preview": "# Introduction\nThis project is used to enable [ToonCrafter](https://github.com/ToonCrafter/ToonCrafter) to be used in Co"
  },
  {
    "path": "requirements.txt",
    "chars": 178,
    "preview": "imageio==2.9.0\nnumpy\nomegaconf==2.1.1\nopencv_python\npandas\npytorch_lightning==1.9.3\npyyaml\nsetuptools\nmoviepy\nav\nxformer"
  }
]
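The per-file index above is a plain JSON array of objects with `path`, `chars`, and a truncated `preview` field. A minimal sketch of consuming that index — using a small hypothetical two-entry sample rather than the full dump — could look like this:

```python
import json

# Hypothetical sample mirroring the index entries above; the real dump
# embeds one object per extracted file with "path", "chars", and a
# truncated "preview" string.
index_json = """
[
  {"path": "ToonCrafter/utils/utils.py", "chars": 2179, "preview": "import os"},
  {"path": "requirements.txt", "chars": 178, "preview": "imageio==2.9.0"}
]
"""

entries = json.loads(index_json)

# Total extracted characters across the indexed files.
total_chars = sum(e["chars"] for e in entries)

# Largest file by character count.
largest = max(entries, key=lambda e: e["chars"])

print(total_chars)      # 2357
print(largest["path"])  # ToonCrafter/utils/utils.py
```

The same pattern scales to the full 98-entry index, e.g. for filtering files by extension or size before feeding them to a model.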

About this extraction

This page contains the full source code of the AIGODLIKE/ComfyUI-ToonCrafter GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 98 files (933.1 KB, approximately 235.4k tokens) and includes a symbol index of 1216 extracted functions, classes, methods, constants, and types. It can be used with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
