Full Code of Lightricks/LTX-Video for AI

Repository: Lightricks/LTX-Video
Branch: main
Commit: 4b2d05305762
Files: 54
Total size: 418.1 KB

Directory structure:
gitextract_jt2joi_e/

├── .gitattributes
├── .github/
│   └── workflows/
│       └── pylint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs/
│   ├── ltxv-13b-0.9.8-dev-fp8.yaml
│   ├── ltxv-13b-0.9.8-dev.yaml
│   ├── ltxv-13b-0.9.8-distilled-fp8.yaml
│   ├── ltxv-13b-0.9.8-distilled.yaml
│   ├── ltxv-2b-0.9.1.yaml
│   ├── ltxv-2b-0.9.5.yaml
│   ├── ltxv-2b-0.9.6-dev.yaml
│   ├── ltxv-2b-0.9.6-distilled.yaml
│   ├── ltxv-2b-0.9.8-distilled-fp8.yaml
│   ├── ltxv-2b-0.9.8-distilled.yaml
│   └── ltxv-2b-0.9.yaml
├── inference.py
├── ltx_video/
│   ├── __init__.py
│   ├── inference.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── autoencoders/
│   │   │   ├── __init__.py
│   │   │   ├── causal_conv3d.py
│   │   │   ├── causal_video_autoencoder.py
│   │   │   ├── conv_nd_factory.py
│   │   │   ├── dual_conv3d.py
│   │   │   ├── latent_upsampler.py
│   │   │   ├── pixel_norm.py
│   │   │   ├── pixel_shuffle.py
│   │   │   ├── vae.py
│   │   │   ├── vae_encode.py
│   │   │   └── video_autoencoder.py
│   │   └── transformers/
│   │       ├── __init__.py
│   │       ├── attention.py
│   │       ├── embeddings.py
│   │       ├── symmetric_patchifier.py
│   │       └── transformer3d.py
│   ├── pipelines/
│   │   ├── __init__.py
│   │   ├── crf_compressor.py
│   │   └── pipeline_ltx_video.py
│   ├── schedulers/
│   │   ├── __init__.py
│   │   └── rf.py
│   └── utils/
│       ├── __init__.py
│       ├── diffusers_config_mapping.py
│       ├── prompt_enhance_utils.py
│       ├── skip_layer_strategy.py
│       └── torch_utils.py
├── pyproject.toml
└── tests/
    ├── conftest.py
    ├── test_configs.py
    ├── test_inference.py
    ├── test_scheduler.py
    ├── test_vae.py
    └── utils/
        └── .gitattributes

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
tests/utils/car.png filter=lfs diff=lfs merge=lfs -text


================================================
FILE: .github/workflows/pylint.yml
================================================
name: Ruff

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - name: Checkout repository and submodules
      uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install ruff==0.2.2 black==24.2.0
    - name: Analyzing the code with ruff
      run: |
        ruff $(git ls-files '*.py')
    - name: Verify that no Black changes are required
      run: |
        black --check $(git ls-files '*.py')


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# From inference.py
outputs/
*.mp4
*.png
!tests/utils/car.png

================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.2.2
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]  # Automatically fix issues if possible.
        types: [python]  # Ensure it only runs on .py files.

  - repo: https://github.com/psf/black
    rev: 24.2.0  # Specify the version of Black you want
    hooks:
      - id: black
        name: Black code formatter
        language_version: python3  # Use the Python version you're targeting (e.g., 3.10)

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">

# LTX-Video

[![Website](https://img.shields.io/badge/Website-LTXV-181717?logo=google-chrome)](https://ltx.video)
[![Model](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)](https://huggingface.co/Lightricks/LTX-Video)
[![Demo](https://img.shields.io/badge/Demo-Try%20Now-brightgreen?logo=vercel)](https://app.ltx.studio/ltx-2-playground/t2v)
[![Paper](https://img.shields.io/badge/Paper-arXiv-B31B1B?logo=arxiv)](https://arxiv.org/abs/2501.00103)
[![Trainer](https://img.shields.io/badge/LTXV-Trainer-9146FF?logo=github)](https://github.com/Lightricks/LTX-Video-Trainer)
[![Discord](https://img.shields.io/badge/Join-Discord-5865F2?logo=discord)](https://discord.gg/ltxplatform)

This is the official repository for LTX-Video.

</div>

---

## 🚀 **New: LTX-2 is Now Available!**

**We're excited to announce [LTX-2](https://github.com/Lightricks/LTX-2) - the next generation of LTX with synchronized audio+video generation!**

LTX-2 is the first DiT-based audio-video foundation model that contains all core capabilities of modern video generation in one model. **LTX-2 is now the primary home for LTX development** and includes significant improvements:

- 🎵 **Synchronized Audio+Video Generation** - Generate videos with perfectly synchronized audio
- 🎬 **Latest Model** - LTX-2 with improved quality and capabilities
- 🔌 **ComfyUI Integration** - Built into ComfyUI core for seamless workflows
- 🎯 **Advanced Features:**
  - Multiple keyframe support
  - IC-LoRA control models for precise generation
  - Standard LoRA support for style customization
  - Latent upsampler for multiscale pipelines
- 🛠️ **Training Tools** - LoRA training capabilities
- 📚 **Comprehensive Documentation** - Full documentation at [https://docs.ltx.video](https://docs.ltx.video)
- 🔄 **Active Development** - Ongoing improvements and community support

**[👉 Check out LTX-2 here](https://github.com/Lightricks/LTX-2)**

**[📖 View Documentation](https://docs.ltx.video)**

---

## Table of Contents

- [Introduction](#introduction)
- [What's New](#news)
- [Models](#models)
- [Quick Start Guide](#quick-start-guide)
  - [Online demo](#online-inference)
  - [Run locally](#run-locally)
    - [Installation](#installation)
    - [Inference](#inference)
  - [ComfyUI Integration](#comfyui-integration)
  - [Diffusers Integration](#diffusers-integration)
- [Model User Guide](#model-user-guide)
- [Community Contribution](#community-contribution)
- [Training](#training)
- [Control Models](#control-models)
- [Join Us!](#join-us)
- [Acknowledgement](#acknowledgement)

# Introduction

LTX-Video is the first DiT-based video generation model that contains all core capabilities of modern video generation in one model: synchronized audio and video, high fidelity, multiple performance modes, production-ready outputs, API access, and open access. It can generate up to 50 FPS videos at native 4K resolution with synchronized audio in one pass.
The model is trained on a large-scale dataset of diverse videos and can generate high-resolution videos with realistic and diverse content.

The model supports image-to-video, multi-keyframe conditioning, keyframe-based animation, video extension (both forward and backward), video-to-video transformations, and any combination of these features.

### Image-to-video examples
| | | |
|:---:|:---:|:---:|
| ![example1](./docs/_static/ltx-video_i2v_example_00001.gif) | ![example2](./docs/_static/ltx-video_i2v_example_00002.gif) | ![example3](./docs/_static/ltx-video_i2v_example_00003.gif) |
| ![example4](./docs/_static/ltx-video_i2v_example_00004.gif) | ![example5](./docs/_static/ltx-video_i2v_example_00005.gif) |  ![example6](./docs/_static/ltx-video_i2v_example_00006.gif) |
| ![example7](./docs/_static/ltx-video_i2v_example_00007.gif) |  ![example8](./docs/_static/ltx-video_i2v_example_00008.gif) | ![example9](./docs/_static/ltx-video_i2v_example_00009.gif) |

### Controlled video examples
| | | |
|:---:|:---:|:---:|
| ![control0](./docs/_static/ltx-video_ic_2v_example_00000.gif) | ![control1](./docs/_static/ltx-video_ic_2v_example_00001.gif) | ![control2](./docs/_static/ltx-video_ic_2v_example_00002.gif) |

| | |
|:---:|:---:|
| ![control3](./docs/_static/ltx-video_ic_2v_example_00003.gif) | ![control4](./docs/_static/ltx-video_ic_2v_example_00004.gif) |

# News

## October 23, 2025: LTX-2 Announced

Today we announced our newest foundation model, LTX-2. LTX-2 represents a major leap forward from our previous model, LTXV 0.9.8. Here’s what’s new:
* **Audio + Video, Together**: Visuals and sound are generated in one coherent process, with motion, dialogue, ambience, and music flowing simultaneously.
* **4K Fidelity**: Professional-grade precision with native 4K and up to 50 fps, sharp textures, clean motion, and synchronized audio.
* **Longer Generations**: LTX-2 supports longer, continuous clips with synchronized audio up to 10 seconds.
* **Low Cost & Efficiency**: Up to 50% lower compute cost than competing models, powered by a multi-GPU inference stack.
* **Creative Control**: Multi-keyframe conditioning, 3D camera logic, and LoRA fine-tuning deliver frame-level precision and style consistency.

For more details, please see our [blog post](https://website.ltx.video/blog/introducing-ltx-2). LTX-2 model weights, code, and benchmarks will be released to the community later in 2025. 


## July 16th, 2025: New distilled models v0.9.8 with up to 60 seconds of video
- Long shot generation in LTXV-13B!
  * LTX-Video now supports up to 60 seconds of video.
  * Also compatible with the official IC-LoRAs.
  * Try now in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-i2v-long-multi-prompt.json).
- Release new distilled models:
  * 13B distilled model [ltxv-13b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml)
  * 2B distilled model [ltxv-2b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml)
  * Both models are distilled from the same base model [ltxv-13b-0.9.8-dev](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml) and are compatible for use together in the same multiscale pipeline.
  * Improved prompt understanding and detail generation
  * Includes corresponding FP8 weights and workflows.
- Release a new detailer model [LTX-Video-ICLoRA-detailer-13B-0.9.8](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8)
  * Available in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-upscale.json).

## July 8th, 2025: New Control Models Released!
- Released three new control models for LTX-Video on HuggingFace:
    * **Depth Control**: [LTX-Video-ICLoRA-depth-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-depth-13b-0.9.7)
    * **Pose Control**: [LTX-Video-ICLoRA-pose-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-pose-13b-0.9.7)
    * **Canny Control**: [LTX-Video-ICLoRA-canny-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-canny-13b-0.9.7)


## May 14th, 2025: New distilled model 13B v0.9.7:
- Release a new 13B distilled model [ltxv-13b-0.9.7-distilled](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors)
    * Amazing for iterative work - generates HD videos in 10 seconds, with low-res preview after just 3 seconds (on H100)!
    * Does not require classifier-free guidance or spatio-temporal guidance.
    * Supports sampling with 8 (recommended) or fewer diffusion steps.
    * Also released a LoRA version of the distilled model, [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors)
        * Requires only 1GB of VRAM
        * Can be used with the full 13B model for fast inference
- Release a new quantized distilled model [ltxv-13b-0.9.7-distilled-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors) for *real-time* generation (on H100) with even less VRAM

## May 5th, 2025: New model 13B v0.9.7:
- Release a new 13B model [ltxv-13b-0.9.7-dev](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors)
- Release a new quantized model [ltxv-13b-0.9.7-dev-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors) for faster inference with less VRAM
- Release new upscalers
  * [ltxv-temporal-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors)
  * [ltxv-spatial-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors)
- Breakthrough prompt adherence and physical understanding.
- New Pipeline for multi-scale video rendering for fast and high quality results


## April 15th, 2025: New checkpoints v0.9.6:
- Release a new checkpoint [ltxv-2b-0.9.6-dev-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-dev-04-25.safetensors) with improved quality
- Release a new distilled model [ltxv-2b-0.9.6-distilled-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors)
    * 15x faster inference than the non-distilled model.
    * Does not require classifier-free guidance or spatio-temporal guidance.
    * Supports sampling with 8 (recommended) or fewer diffusion steps.
- Improved prompt adherence, motion quality and fine details.
- New default resolution and FPS: 1216 × 704 pixels at 30 FPS
    * Still real time on H100 with the distilled model.
    * Other resolutions and FPS are still supported.
- Support stochastic inference (can improve visual quality when using the distilled model)

## March 5th, 2025: New checkpoint v0.9.5
- New license for commercial use ([OpenRail-M](https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.5.license.txt))
- Release a new checkpoint v0.9.5 with improved quality
- Support keyframes and video extension
- Support higher resolutions
- Improved prompt understanding
- Improved VAE
- New online web app in [LTX-Studio](https://app.ltx.studio/ltx-video)
- Automatic prompt enhancement

## February 20th, 2025: More inference options
- Improve STG (Spatiotemporal Guidance) for LTX-Video
- Support MPS on macOS with PyTorch 2.3.0
- Add support for 8-bit model, LTX-VideoQ8
- Add TeaCache for LTX-Video
- Add [ComfyUI-LTXTricks](#comfyui-integration)
- Add Diffusion-Pipe

## December 31st, 2024: Research paper
- Release the [research paper](https://arxiv.org/abs/2501.00103)

## December 20th, 2024: New checkpoint v0.9.1
- Release a new checkpoint v0.9.1 with improved quality
- Support for STG / PAG
- Support loading checkpoints of LTX-Video in Diffusers format (conversion is done on-the-fly)
- Support offloading unused parts to CPU
- Support the new timestep-conditioned VAE decoder
- Reference contributions from the community in the readme file
- Relax transformers dependency

## November 21st, 2024: Initial release v0.9.0
- Initial release of LTX-Video
- Support text-to-video and image-to-video generation


# Models

| Name                    | Notes                                                                                      | inference.py config                                                                                                                                      | ComfyUI workflow (Recommended) |
|-------------------------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|
| ltxv-13b-0.9.8-dev                   | Highest quality, requires more VRAM                                                        | [ltxv-13b-0.9.8-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml)                                             | [ltxv-13b-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base.json)             |
| [ltxv-13b-0.9.8-mix](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)            | Mix ltxv-13b-dev and ltxv-13b-distilled in the same multi-scale rendering workflow for balanced speed-quality | N/A                                             | [ltxv-13b-i2v-mixed-multiscale.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-mixed-multiscale.json)             |
| [ltxv-13b-0.9.8-distilled](https://app.ltx.studio/motion-workspace?videoModel=ltxv)        | Faster, less VRAM usage, slight quality reduction compared to 13b. Ideal for rapid iterations | [ltxv-13b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml)                                    | [ltxv-13b-dist-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json) |
| ltxv-2b-0.9.8-distilled        | Smaller model, slight quality reduction compared to 13b distilled. Ideal for fast generation with light VRAM usage | [ltxv-2b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml)                                    | N/A |
| ltxv-13b-0.9.8-dev-fp8               | Quantized version of ltxv-13b | [ltxv-13b-0.9.8-dev-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev-fp8.yaml) | [ltxv-13b-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base-fp8.json) |
| ltxv-13b-0.9.8-distilled-fp8     | Quantized version of ltxv-13b-distilled | [ltxv-13b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled-fp8.yaml) | [ltxv-13b-dist-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) |
| ltxv-2b-0.9.8-distilled-fp8     | Quantized version of ltxv-2b-distilled | [ltxv-2b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled-fp8.yaml) | N/A |
| ltxv-2b-0.9.6                     | Good quality, lower VRAM requirement than ltxv-13b                                         | [ltxv-2b-0.9.6-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-dev.yaml)                                                 | [ltxvideo-i2v.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v.json)             |
| ltxv-2b-0.9.6-distilled         | 15× faster, real-time capable, fewer steps needed, no STG/CFG required                     | [ltxv-2b-0.9.6-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-distilled.yaml)                                     | [ltxvideo-i2v-distilled.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v-distilled.json)             |


# Quick Start Guide

## Online inference
The model is accessible right away via the following links:
- [LTX-Studio image-to-video (13B-mix)](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)
- [LTX-Studio image-to-video (13B distilled)](https://app.ltx.studio/motion-workspace?videoModel=ltxv)
- [Fal.ai image-to-video (13B full)](https://fal.ai/models/fal-ai/ltx-video-13b-dev/image-to-video)
- [Fal.ai image-to-video (13B distilled)](https://fal.ai/models/fal-ai/ltx-video-13b-distilled/image-to-video)
- [Replicate image-to-video](https://replicate.com/lightricks/ltx-video)

## Run locally

### Installation
The codebase was tested with Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
On macOS, MPS was tested with PyTorch 2.3.0, and should support PyTorch == 2.3 or >= 2.6.

```bash
git clone https://github.com/Lightricks/LTX-Video.git
cd LTX-Video

# create env
python -m venv env
source env/bin/activate
python -m pip install -e .\[inference\]
```
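
The version constraints above can be checked programmatically before installing. The following is a minimal sketch, not part of the repository: the helper names `parse_version` and `torch_version_supported` are hypothetical, and the check only encodes the ranges stated in this section.

```python
import re


def parse_version(version: str) -> tuple[int, ...]:
    """Extract the leading numeric release, e.g. '2.3.0+cu121' -> (2, 3, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def torch_version_supported(version: str, use_mps: bool = False) -> bool:
    """Check a PyTorch version string against the ranges listed in this README."""
    release = parse_version(version)
    if use_mps:
        # macOS/MPS: PyTorch == 2.3 or >= 2.6
        return release[:2] == (2, 3) or release >= (2, 6)
    # CUDA/CPU: PyTorch >= 2.1.2
    return release >= (2, 1, 2)
```

With PyTorch installed, the check can be run as `torch_version_supported(torch.__version__, use_mps=True)`.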

#### FP8 Kernels (optional)

[FP8 kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) developed for LTX-Video provide a performance boost on supported graphics cards (Ada architecture and later). To install the FP8 kernels, follow the instructions in that repository.

### Inference

📝 **Note:** For best results, we recommend using our [ComfyUI](#comfyui-integration) workflow. We're working on updating the inference.py script to match the high quality and output fidelity of ComfyUI.

To use our model, please follow the inference code in [inference.py](./inference.py):

#### For image-to-video generation:

```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_PATH --conditioning_start_frames 0 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```

#### Extending a video:

📝 **Note:** Input video segments must contain a multiple of 8 frames plus 1 (e.g., 9, 17, 25, etc.), and the target frame number should be a multiple of 8.


```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths VIDEO_PATH --conditioning_start_frames START_FRAME --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
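
The frame-count rule in the note above can be verified before launching a run. This is a minimal sketch of the arithmetic only; the helper names are hypothetical and do not exist in the repository.

```python
def is_valid_segment_length(num_frames: int) -> bool:
    """True when the segment has 8k + 1 frames (9, 17, 25, ...; 1 is a single image)."""
    return num_frames >= 1 and (num_frames - 1) % 8 == 0


def is_valid_start_frame(start_frame: int) -> bool:
    """True when the target frame number is a non-negative multiple of 8."""
    return start_frame >= 0 and start_frame % 8 == 0


def trim_to_valid_length(num_frames: int) -> int:
    """Largest frame count of the form 8k + 1 that does not exceed num_frames."""
    if num_frames < 1:
        raise ValueError("A segment needs at least one frame.")
    return ((num_frames - 1) // 8) * 8 + 1
```

For example, a 30-frame clip would need to be trimmed to 25 frames before being used as a conditioning segment.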

#### For video generation with multiple conditions:

You can now generate a video conditioned on a set of images and/or short video segments.
Simply provide a list of paths to the images or video segments you want to condition on, along with their target frame numbers in the generated video. You can also specify the conditioning strength for each item (default: 1.0).

```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_OR_VIDEO_PATH_1 IMAGE_OR_VIDEO_PATH_2 --conditioning_start_frames TARGET_FRAME_1 TARGET_FRAME_2 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
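
When scripting several runs, the command above can be assembled programmatically. The sketch below uses only the flags shown in this section; the function name `build_inference_args`, the example prompt, and the file names are hypothetical.

```python
from typing import List, Sequence


def build_inference_args(
    prompt: str,
    media_paths: Sequence[str],
    start_frames: Sequence[int],
    height: int,
    width: int,
    num_frames: int,
    seed: int,
    pipeline_config: str = "configs/ltxv-13b-0.9.8-distilled.yaml",
) -> List[str]:
    """Build the inference.py argument list for one or more conditioning items."""
    if len(media_paths) != len(start_frames):
        raise ValueError("Each conditioning item needs exactly one start frame.")
    return [
        "python", "inference.py",
        "--prompt", prompt,
        "--conditioning_media_paths", *media_paths,
        "--conditioning_start_frames", *(str(f) for f in start_frames),
        "--height", str(height),
        "--width", str(width),
        "--num_frames", str(num_frames),
        "--seed", str(seed),
        "--pipeline_config", pipeline_config,
    ]


args = build_inference_args(
    prompt="A red kite drifts over a beach",
    media_paths=["first.png", "clip.mp4"],  # placeholder file names
    start_frames=[0, 64],
    height=480,
    width=704,
    num_frames=121,
    seed=42,
)
# To launch the run from the repository root: subprocess.run(args, check=True)
```

Passing the arguments as a list avoids shell-quoting problems with prompts that contain spaces or quotes.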

### Using as a library

```python
from ltx_video.inference import infer, InferenceConfig

infer(
    InferenceConfig(
        pipeline_config="configs/ltxv-13b-0.9.8-distilled.yaml",
        prompt=PROMPT,
        height=HEIGHT,
        width=WIDTH,
        num_frames=NUM_FRAMES,
        output_path="output.mp4",
    )
)
```

## ComfyUI Integration
To use our model with ComfyUI, please follow the instructions at [https://github.com/Lightricks/ComfyUI-LTXVideo/](https://github.com/Lightricks/ComfyUI-LTXVideo/).

## Diffusers Integration
To use our model with the Diffusers Python library, check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).

Diffusers also supports an 8-bit version of LTX-Video; [see details below](#ltx-videoq8).

# Model User Guide

## 📝 Prompt Engineering

When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure:

* Start with main action in a single sentence
* Add specific details about movements and gestures
* Describe character/object appearances precisely
* Include background and environment details
* Specify camera angles and movements
* Describe lighting and colors
* Note any changes or sudden events
* See [examples](#introduction) for more inspiration.
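
The structure above can be applied mechanically when prompts are generated from templates. This is a minimal sketch under stated assumptions: the helper `build_prompt` is hypothetical, it counts words by splitting on whitespace, and the 200-word limit is the guideline from this section rather than a limit enforced by the model.

```python
def build_prompt(parts: list[str], max_words: int = 200) -> str:
    """Join ordered prompt parts into one paragraph and enforce a word budget."""
    prompt = " ".join(part.strip() for part in parts if part.strip())
    word_count = len(prompt.split())
    if word_count > max_words:
        raise ValueError(f"Prompt has {word_count} words; keep it within {max_words}.")
    return prompt


prompt = build_prompt(
    [
        "A woman walks along a rain-soaked street at night.",  # main action
        "She pulls her coat tighter and glances over her shoulder.",  # movements
        "Neon signs reflect off the wet pavement behind her.",  # environment
        "The camera tracks her from the side at eye level.",  # camera
        "Cool blue light mixes with warm orange reflections.",  # lighting
    ]
)
```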

### Automatic Prompt Enhancement

When using `LTXVideoPipeline` directly, you can enable prompt enhancement by setting `enhance_prompt=True`.

## 🎮 Parameter Guide

* Resolution Preset: Use higher resolutions for detailed scenes and lower resolutions for faster generation and simpler scenes. The model works on resolutions divisible by 32 and on frame counts of the form 8k + 1 (e.g. 257). If the resolution is not divisible by 32 or the frame count is not of the form 8k + 1, the input is padded with -1 and the result is then cropped to the requested resolution and number of frames. The model works best at resolutions under 720 x 1280 and with a number of frames below 257
* Seed: Save seed values to recreate specific styles or compositions you like
* Guidance Scale: 3-3.5 are the recommended values
* Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
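
The padding rule in the Resolution Preset item can be illustrated with a few lines of arithmetic. This is a sketch of the rounding described above, not the repository's implementation; the helper names `round_up` and `padded_shape` are hypothetical.

```python
def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple (ceiling division)."""
    return -(-value // multiple) * multiple


def padded_shape(height: int, width: int, num_frames: int) -> tuple[int, int, int]:
    """Return the (height, width, num_frames) the model would operate on.

    Height and width are rounded up to multiples of 32, and the frame count
    is rounded up to the next value of the form 8k + 1.
    """
    return (
        round_up(height, 32),
        round_up(width, 32),
        round_up(num_frames - 1, 8) + 1,
    )
```

A request for 720 x 1280 with 100 frames would therefore be processed at 736 x 1280 with 105 frames and cropped back afterwards.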

📝 For advanced parameter usage, please see `python inference.py --help`

## Community Contribution

### ComfyUI-LTXTricks 🛠️

A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.

- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
- **Features:**
  - 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
  - ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
  - 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
  - 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
  - ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
  - 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).


### LTX-VideoQ8 🎱 <a id="ltx-videoq8"></a>

**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA Ada GPUs.

- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
- **Features:**
  - 🚀 Up to 3X speed-up with no accuracy loss
  - 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
  - 🛠️ Fine-tune 2B transformer models with precalculated latents
- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)
- **Diffusers integration:** A diffusers integration for the 8-bit model is already out! [Details here](https://github.com/sayakpaul/q8-ltx-video)


### TeaCache for LTX-Video 🍵 <a id="TeaCache"></a>

**TeaCache** is a training-free caching approach that leverages timestep differences across model outputs to accelerate LTX-Video inference by up to 2x without significant visual quality degradation.

- **Repository:** [TeaCache4LTX-Video](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4LTX-Video)
- **Features:**
  - 🚀 Speeds up LTX-Video inference.
  - 📊 Adjustable trade-offs between speed (up to 2x) and visual quality using configurable parameters.
  - 🛠️ No retraining required: Works directly with existing models.

### Your Contribution

...is welcome! If you have a project or tool that integrates with LTX-Video,
please let us know by opening an issue or pull request.

# Training

We provide an open-source repository for fine-tuning the LTX-Video model: [LTX-Video-Trainer](https://github.com/Lightricks/LTX-Video-Trainer).
This repository supports both the 2B and 13B model variants, enabling full fine-tuning as well as LoRA (Low-Rank Adaptation) fine-tuning for more efficient training. This includes:

- **Control LoRAs**: Train custom control models like depth, pose, and canny control
- **Effect LoRAs**: Create specialized effects and transformations for video generation

Explore the repository to customize the model for your specific use cases!
More information and training instructions can be found in the [README](https://github.com/Lightricks/LTX-Video-Trainer/blob/main/README.md).

# Control Models

The [ComfyUI-LTXVideo](https://github.com/Lightricks/ComfyUI-LTXVideo) repository now contains workflows and models for three specialized control types that enable precise control over LTX-Video generation:

Pose Control, Depth Control and Canny Control

**Example ComfyUI Workflow (for all control types):** [ic-lora.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ic_lora/ic-lora.json)

# Join Us

Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?

At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created.

If you are passionate about AI, computer vision, and video generation, we would love to hear from you!

Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.

# Acknowledgement

We are grateful to the following awesome projects, which we relied on when implementing LTX-Video:
* [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.


## Citation

📄 Our tech report is out! If you find our work helpful, please ⭐️ star the repository and cite our paper.

```bibtex
@article{HaCohen2024LTXVideo,
  title={LTX-Video: Realtime Video Latent Diffusion},
  author={HaCohen, Yoav and Chiprut, Nisan and Brazowski, Benny and Shalem, Daniel and Moshe, Dudu and Richardson, Eitan and Levin, Eran and Shiran, Guy and Zabari, Nir and Gordon, Ori and Panet, Poriya and Weissbuch, Sapir and Kulikov, Victor and Bitterman, Yaki and Melumian, Zeev and Bibi, Ofir},
  journal={arXiv preprint arXiv:2501.00103},
  year={2024}
}
```


================================================
FILE: configs/ltxv-13b-0.9.8-dev-fp8.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996,  0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true


================================================
FILE: configs/ltxv-13b-0.9.8-dev.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996,  0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true

================================================
FILE: configs/ltxv-13b-0.9.8-distilled-fp8.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6


================================================
FILE: configs/ltxv-13b-0.9.8-distilled.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6


================================================
FILE: configs/ltxv-2b-0.9.1.yaml
================================================
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

================================================
FILE: configs/ltxv-2b-0.9.5.yaml
================================================
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

================================================
FILE: configs/ltxv-2b-0.9.6-dev.yaml
================================================
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

================================================
FILE: configs/ltxv-2b-0.9.6-distilled.yaml
================================================
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
num_inference_steps: 8
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: true

================================================
FILE: configs/ltxv-2b-0.9.8-distilled-fp8.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]


================================================
FILE: configs/ltxv-2b-0.9.8-distilled.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]


================================================
FILE: configs/ltxv-2b-0.9.yaml
================================================
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

================================================
FILE: inference.py
================================================
from transformers import HfArgumentParser

from ltx_video.inference import infer, InferenceConfig


def main():
    parser = HfArgumentParser(InferenceConfig)
    config = parser.parse_args_into_dataclasses()[0]
    infer(config=config)


if __name__ == "__main__":
    main()


================================================
FILE: ltx_video/__init__.py
================================================


================================================
FILE: ltx_video/inference.py
================================================
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml

import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
import torchvision.transforms.functional as TVF
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download
from dataclasses import dataclass, field

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import (
    ConditioningItem,
    LTXVideoPipeline,
    LTXMultiScalePipeline,
)
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
import ltx_video.pipelines.crf_compressor as crf_compressor

logger = logging.get_logger("LTX-Video")


def get_total_gpu_memory():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return total_memory
    return 0


def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str) or a PIL Image object
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        just_crop: If True, only crop the image to the target size without resizing
    """
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)


def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:

    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return the padding amounts (not a padded tensor)
    # Padding format is (left, right, top, bottom)
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Keep only letters and whitespace, and convert to lowercase
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )

    # Split into words
    words = clean_text.split()

    # Build result string keeping track of length
    result = []
    current_length = 0

    for word in words:
        # Count only the word's letters; the "-" separators are not included in max_len
        new_length = current_length + len(word)

        if new_length <= max_len:
            result.append(word)
            current_length += len(word)
        else:
            break

    return "-".join(result)


# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    for i in range(index_range):
        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )


def seed_everething(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    if precision == "float8_e4m3fn":
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )

            transformer = Transformer3DModel.from_pretrained(
                ckpt_path, dtype=torch.float8_e4m3fn
            )
            patch_transformer_for_q8_kernels(transformer)
            return transformer
        except ImportError as e:
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            ) from e
    elif precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    else:
        return Transformer3DModel.from_pretrained(ckpt_path)


def create_ltx_video_pipeline(
    ckpt_path: str,
    precision: str,
    text_encoder_model_name_or_path: str,
    sampler: Optional[str] = None,
    device: Optional[str] = None,
    enhance_prompt: bool = False,
    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
) -> LTXVideoPipeline:
    ckpt_path = Path(ckpt_path)
    assert os.path.exists(
        ckpt_path
    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"

    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = create_transformer(ckpt_path, precision)

    # Use constructor if sampler is specified, otherwise use from_pretrained
    if sampler == "from_checkpoint" or not sampler:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    else:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
        )

    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_model_name_or_path, subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_model_name_or_path, subfolder="tokenizer"
    )

    transformer = transformer.to(device)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)

    if enhance_prompt:
        prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
            torch_dtype="bfloat16",
        )
        prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
        )
    else:
        prompt_enhancer_image_caption_model = None
        prompt_enhancer_image_caption_processor = None
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

    vae = vae.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
        "allowed_inference_steps": allowed_inference_steps,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to(device)
    return pipeline


def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler


def load_pipeline_config(pipeline_config: str):
    current_file = Path(__file__)

    path = None
    if os.path.isfile(current_file.parent / pipeline_config):
        path = current_file.parent / pipeline_config
    elif os.path.isfile(pipeline_config):
        path = pipeline_config
    else:
        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")

    with open(path, "r") as f:
        return yaml.safe_load(f)


@dataclass
class InferenceConfig:
    prompt: str = field(metadata={"help": "Prompt for the generation"})

    output_path: str = field(
        default_factory=lambda: Path(
            f"outputs/{datetime.today().strftime('%Y-%m-%d')}"
        ),
        metadata={"help": "Path to the folder to save the output video"},
    )

    # Pipeline settings
    pipeline_config: str = field(
        default="configs/ltxv-13b-0.9.8-dev.yaml",
        metadata={"help": "Path to the pipeline config file"},
    )
    seed: int = field(
        default=171198, metadata={"help": "Random seed for the inference"}
    )
    height: int = field(
        default=704, metadata={"help": "Height of the output video frames"}
    )
    width: int = field(
        default=1216, metadata={"help": "Width of the output video frames"}
    )
    num_frames: int = field(
        default=121,
        metadata={"help": "Number of frames to generate in the output video"},
    )
    frame_rate: int = field(
        default=30, metadata={"help": "Frame rate for the output video"}
    )
    offload_to_cpu: bool = field(
        default=False, metadata={"help": "Offloading unnecessary computations to CPU."}
    )
    negative_prompt: str = field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        metadata={"help": "Negative prompt for undesired features"},
    )

    # Video-to-video arguments
    input_media_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the input video (or image) to be modified using the video-to-video pipeline"
        },
    )

    # Conditioning
    image_cond_noise_scale: float = field(
        default=0.15,
        metadata={"help": "Amount of noise to add to the conditioned image"},
    )
    conditioning_media_paths: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of paths to conditioning media (images or videos). Each path will be used as a conditioning item."
        },
    )
    conditioning_strengths: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items."
        },
    )
    conditioning_start_frames: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of frame indices where each conditioning item should be applied. Must match the number of conditioning items."
        },
    )


def infer(config: InferenceConfig):
    pipeline_config = load_pipeline_config(config.pipeline_config)

    ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
    if not os.path.isfile(ltxv_model_name_or_path):
        ltxv_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=ltxv_model_name_or_path,
            repo_type="model",
        )
    else:
        ltxv_model_path = ltxv_model_name_or_path

    spatial_upscaler_model_name_or_path = pipeline_config.get(
        "spatial_upscaler_model_path"
    )
    if spatial_upscaler_model_name_or_path and not os.path.isfile(
        spatial_upscaler_model_name_or_path
    ):
        spatial_upscaler_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=spatial_upscaler_model_name_or_path,
            repo_type="model",
        )
    else:
        spatial_upscaler_model_path = spatial_upscaler_model_name_or_path

    conditioning_media_paths = config.conditioning_media_paths
    conditioning_strengths = config.conditioning_strengths
    conditioning_start_frames = config.conditioning_start_frames

    # Validate conditioning arguments
    if conditioning_media_paths:
        # Use default strengths of 1.0
        if not conditioning_strengths:
            conditioning_strengths = [1.0] * len(conditioning_media_paths)
        if not conditioning_start_frames:
            raise ValueError(
                "If `conditioning_media_paths` is provided, "
                "`conditioning_start_frames` must also be provided"
            )
        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
            conditioning_media_paths
        ) != len(conditioning_start_frames):
            raise ValueError(
                "`conditioning_media_paths`, `conditioning_strengths`, "
                "and `conditioning_start_frames` must have the same length"
            )
        if any(s < 0 or s > 1 for s in conditioning_strengths):
            raise ValueError("All conditioning strengths must be between 0 and 1")
        if any(f < 0 or f >= config.num_frames for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {config.num_frames-1}"
            )

    seed_everething(config.seed)
    if config.offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        offload_to_cpu = config.offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(config.output_path)
        if config.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(
        config.height, config.width, height_padded, width_padded
    )

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    device = get_device()

    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]

    prompt_word_count = len(config.prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )

    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which is at or above the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
    sampler = pipeline_config.get("sampler", None)
    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
        "prompt_enhancer_image_caption_model_name_or_path"
    ]
    prompt_enhancer_llm_model_name_or_path = pipeline_config[
        "prompt_enhancer_llm_model_name_or_path"
    ]

    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
        sampler=sampler,
        device=device,
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if config.input_media_path:
        media_item = load_media_file(
            media_path=config.input_media_path,
            height=config.height,
            width=config.width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    # pop() removes the key so it is not forwarded to the pipeline call below,
    # and avoids a KeyError when the config omits stg_mode.
    stg_mode = pipeline_config.pop("stg_mode", "attention_values")
    if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
        skip_layer_strategy = SkipLayerStrategy.AttentionValues
    elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
    elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
        skip_layer_strategy = SkipLayerStrategy.Residual
    elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
        skip_layer_strategy = SkipLayerStrategy.TransformerBlock
    else:
        raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

    # Prepare input for the pipeline
    sample = {
        "prompt": config.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": config.negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    generator = torch.Generator(device=device).manual_seed(config.seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=config.frame_rate,
        **sample,
        media_items=media_item,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=config.image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")


def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters.

    Args:
        conditioning_media_paths: List of paths to conditioning media (images or videos)
        conditioning_strengths: List of conditioning strengths for each media item
        conditioning_start_frames: List of frame indices where each item should be applied
        height: Height of the output frames
        width: Width of the output frames
        num_frames: Number of frames in the output video
        padding: Padding to apply to the frames
        pipeline: LTXVideoPipeline object used for condition video trimming

    Returns:
        A list of ConditioningItem objects.
    """
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        num_input_frames = orig_num_input_frames = get_media_num_frames(path)
        if hasattr(pipeline, "trim_conditioning_sequence") and callable(
            getattr(pipeline, "trim_conditioning_sequence")
        ):
            num_input_frames = pipeline.trim_conditioning_sequence(
                start_frame, orig_num_input_frames, num_frames
            )
        if num_input_frames < orig_num_input_frames:
            logger.warning(
                f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
            )

        media_tensor = load_media_file(
            media_path=path,
            height=height,
            width=width,
            max_frames=num_input_frames,
            padding=padding,
            just_crop=True,
        )
        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
    return conditioning_items


def get_media_num_frames(media_path: str) -> int:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    num_frames = 1
    if is_video:
        reader = imageio.get_reader(media_path)
        num_frames = reader.count_frames()
        reader.close()
    return num_frames


def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    if is_video:
        reader = imageio.get_reader(media_path)
        num_input_frames = min(reader.count_frames(), max_frames)

        # Read and preprocess the relevant frames from the video file.
        frames = []
        for i in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(i))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
            frames.append(frame_tensor)
        reader.close()

        # Stack frames along the temporal dimension
        media_tensor = torch.cat(frames, dim=2)
    else:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        media_tensor = torch.nn.functional.pad(media_tensor, padding)
    return media_tensor


================================================
FILE: ltx_video/models/__init__.py
================================================


================================================
FILE: ltx_video/models/autoencoders/__init__.py
================================================


================================================
FILE: ltx_video/models/autoencoders/causal_conv3d.py
================================================
from typing import Tuple, Union

import torch
import torch.nn as nn


class CausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        return self.conv.weight


================================================
FILE: ltx_video/models/autoencoders/causal_video_autoencoder.py
================================================
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union, List
from pathlib import Path

import torch
import numpy as np
from einops import rearrange
from torch import nn
from diffusers.utils import logging
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from safetensors import safe_open


from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
from ltx_video.models.transformers.attention import Attention
from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    VAE_KEYS_RENAME_DICT,
)

PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CausalVideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if (
            pretrained_model_name_or_path.is_dir()
            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
        ):
            config_local_path = pretrained_model_name_or_path / "config.json"
            config = cls.load_config(config_local_path, **kwargs)

            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))

            statistics_local_path = (
                pretrained_model_name_or_path / "per_channel_statistics.json"
            )
            if statistics_local_path.exists():
                with open(statistics_local_path, "r") as file:
                    data = json.load(file)
                transposed_data = list(zip(*data["data"]))
                data_dict = {
                    col: torch.tensor(vals)
                    for col, vals in zip(data["columns"], transposed_data)
                }
                std_of_means = data_dict["std-of-means"]
                mean_of_means = data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
                    std_of_means
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
                    mean_of_means
                )

        elif pretrained_model_name_or_path.is_dir():
            config_path = pretrained_model_name_or_path / "vae" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for VAE is not supported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]

            state_dict_path = (
                pretrained_model_name_or_path
                / "vae"
                / "diffusion_pytorch_model.safetensors"
            )

            state_dict = {}
            with safe_open(state_dict_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)

                state_dict[new_key] = state_dict.pop(key)

        elif pretrained_model_name_or_path.is_file() and str(
            pretrained_model_name_or_path
        ).endswith(".safetensors"):
            state_dict = {}
            with safe_open(
                pretrained_model_name_or_path, framework="pt", device="cpu"
            ) as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
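            configs = json.loads(metadata["config"])
            config = configs["vae"]
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                f"Unsupported VAE checkpoint path: {pretrained_model_name_or_path}"
            )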

        video_vae = cls.from_config(config)
        if "torch_dtype" in kwargs:
            video_vae.to(kwargs["torch_dtype"])
        video_vae.load_state_dict(state_dict)
        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "CausalVideoAutoencoder"
        ), "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)
        normalize_latent_channels = config.get("normalize_latent_channels", False)

        if use_quant_conv and latent_log_var in ["uniform", "constant"]:
            raise ValueError(
                f"latent_log_var={latent_log_var} requires use_quant_conv=False"
            )

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            base_channels=config.get("encoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.encoder.blocks_desc
                    if block[0]
                    in [
                        "compress_space",
                        "compress_all",
                        "compress_all_x_y",
                        "compress_all_res",
                        "compress_space_res",
                    ]
                ]
            )
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        return 2 ** len(
            [
                block
                for block in self.encoder.blocks_desc
                if block[0]
                in [
                    "compress_time",
                    "compress_all",
                    "compress_all_x_y",
                    "compress_all_res",
                    "compress_time_res",
                ]
            ]
        )

    def to_json_string(self) -> str:
        import json

        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
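        if any(key.startswith("vae.") for key in state_dict.keys()):
            # Strip only the leading "vae." prefix; str.replace would also rewrite
            # any later occurrence of "vae." inside the key.
            state_dict = {
                key.removeprefix("vae."): value
                for key, value in state_dict.items()
                if key.startswith("vae.")
            }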
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }

        model_keys = set(name for name, _ in self.named_modules())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in ckpt_state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            key_prefix = ".".join(key.split(".")[:-1])
            if "norm" in key and key_prefix not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

        data_dict = {
            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
            for key, value in state_dict.items()
            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }
        if len(data_dict) > 0:
            self.register_buffer("std_of_means", data_dict["std-of-means"])
            self.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels, before patchification.
        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels. The final convolution emits extra channels for the log variance,
            depending on `latent_log_var`.
        blocks (`List[Tuple[str, int | dict]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and either the number of layers or a
            dict of block parameters (e.g. `num_layers`, `multiplier`).
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            How the log variance is represented in the output. Can be either `per_channel`, `uniform`,
            `constant` or `none`.
        spatial_padding_mode (`str`, *optional*, defaults to `zeros`):
            The spatial padding mode forwarded to the convolution layers.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
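        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
        else:
            # Fail early instead of leaving `conv_norm_out` undefined until forward().
            raise ValueError(f"Invalid norm_layer: {norm_layer}")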

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var == "constant":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            conv_out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

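        # checkpoint() invokes the wrapped module itself, so bind the module first
        # and let the returned callable receive the inputs.
        checkpoint_fn = (
            (
                lambda module: partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else (lambda module: module)
        )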

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")
        elif self.latent_log_var == "constant":
            sample = sample[:, :-1, ...]
            # This is the minimal clamp value in DiagonalGaussianDistribution objects.
            approx_ln_0 = -30
            sample = torch.cat(
                [sample, torch.full_like(sample, approx_ln_0)],
                dim=1,
            )

        return sample


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input (latent) channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels, after unpatchification.
        blocks (`List[Tuple[str, int | dict]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use, listed in encoder order and applied in reverse. Each block is a tuple of the
            block name and either the number of layers or a dict of block parameters (e.g. `num_layers`,
            `multiplier`, `inject_noise`).
        base_channels (`int`, *optional*, defaults to 128):
            The number of channels fed to the final convolutional layer.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block. Stored on the module as `layers_per_block`.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to use causal convolutions or not.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the decoder on the timestep.
        spatial_padding_mode (`str`, *optional*, defaults to `zeros`):
            The spatial padding mode forwarded to the convolution layers.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.blocks_desc = blocks

        # Compute output channel to be product of all channel-multiplier blocks
        output_channel = base_channels
        for block_name, block_params in list(reversed(blocks)):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)
            if block_name.startswith("compress"):
                output_channel = output_channel * block_params.get("multiplier", 1)

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in list(reversed(blocks)):
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "attn_res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    attention_head_dim=block_params["attention_head_dim"],
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=False,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                output_channel = output_channel // block_params.get("multiplier", 1)
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 2, 2),
                    residual=block_params.get("residual", False),
                    out_channels_reduction_factor=block_params.get("multiplier", 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown layer: {block_name}")

            self.up_blocks.append(block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
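        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
        else:
            # Fail early instead of leaving `conv_norm_out` undefined until forward().
            raise ValueError(f"Invalid norm_layer: {norm_layer}")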

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(
                torch.tensor(1000.0, dtype=torch.float32)
            )
            self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                output_channel * 2, 0
            )
            self.last_scale_shift_table = nn.Parameter(
                torch.randn(2, output_channel) / output_channel**0.5
            )

    def forward(
        self,
        sample: torch.FloatTensor,
        target_shape,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        assert target_shape is not None, "target_shape must be provided"
        batch_size = sample.shape[0]

        sample = self.conv_in(sample, causal=self.causal)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

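        # checkpoint() invokes the wrapped module itself, so bind the module first
        # and let the returned callable receive the inputs and keyword arguments.
        checkpoint_fn = (
            (
                lambda module: partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else (lambda module: module)
        )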

        sample = sample.to(upscale_dtype)

        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            scaled_timestep = timestep * self.timestep_scale_multiplier

        for up_block in self.up_blocks:
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)

        if self.timestep_conditioning:
            embedded_timestep = self.last_time_embedder(
                timestep=scaled_timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=sample.shape[0],
                hidden_dtype=sample.dtype,
            )
            embedded_timestep = embedded_timestep.view(
                batch_size, embedded_timestep.shape[-1], 1, 1, 1
            )
            ada_values = self.last_scale_shift_table[
                None, ..., None, None, None
            ] + embedded_timestep.reshape(
                batch_size,
                2,
                -1,
                embedded_timestep.shape[-3],
                embedded_timestep.shape[-2],
                embedded_timestep.shape[-1],
            )
            shift, scale = ada_values.unbind(dim=1)
            sample = sample * (1 + scale) + shift

        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
            If `None`, `min(in_channels // 4, 32)` is used.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
        attention_head_dim (`int`, *optional*, defaults to -1):
            The dimension of the attention head. If not positive, no attention is used.
        spatial_padding_mode (`str`, *optional*, defaults to `zeros`):
            The spatial padding mode forwarded to the convolution layers.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which for video inputs is a tensor of shape
        `(batch_size, in_channels, frames, height, width)`.

    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: int = -1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )
        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                in_channels * 4, 0
            )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.attention_blocks = None

        if attention_head_dim > 0:
            if attention_head_dim > in_channels:
                raise ValueError(
                    "attention_head_dim must be less than or equal to in_channels"
                )

            self.attention_blocks = nn.ModuleList(
                [
                    Attention(
                        query_dim=in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        bias=True,
                        out_bias=True,
                        qk_norm="rms_norm",
                        residual_connection=True,
                    )
                    for _ in range(num_layers)
                ]
            )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        timestep_embed = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            timestep_embed = timestep_embed.view(
                batch_size, timestep_embed.shape[-1], 1, 1, 1
            )

        if self.attention_blocks:
            for resnet, attention in zip(self.res_blocks, self.attention_blocks):
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

                # Reshape the hidden states to be (batch_size, frames * height * width, channel)
                batch_size, channel, frames, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, frames * height * width
                ).transpose(1, 2)

                if attention.use_tpu_flash_attention:
                    # Pad the second dimension to be divisible by block_k_major (block in flash attention)
                    seq_len = hidden_states.shape[1]
                    block_k_major = 512
                    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
                    if pad_len > 0:
                        hidden_states = F.pad(
                            hidden_states, (0, 0, 0, pad_len), "constant", 0
                        )

                    # Create a mask with ones for the original sequence length and zeros for the padded indexes
                    mask = torch.ones(
                        (hidden_states.shape[0], seq_len),
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                    if pad_len > 0:
                        mask = F.pad(mask, (0, pad_len), "constant", 0)

                hidden_states = attention(
                    hidden_states,
                    attention_mask=(
                        None if not attention.use_tpu_flash_attention else mask
                    ),
                )

                if attention.use_tpu_flash_attention:
                    # Remove the padding
                    if pad_len > 0:
                        hidden_states = hidden_states[:, :-pad_len, :]

                # Reshape the hidden states back to (batch_size, channel, frames, height, width)
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, frames, height, width
                )
        else:
            for resnet in self.res_blocks:
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

        return hidden_states


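class SpaceToDepthDownsample(nn.Module):
    """
    Downsamples by running a convolution and folding each `stride`-sized block of
    positions into the channel dimension (space-to-depth).

    A skip connection applies the same rearrangement to the input and averages groups
    of `group_size` channels so that it matches `out_channels`. When the temporal
    stride is 2, the first frame is duplicated before downsampling.
    """
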
    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
        super().__init__()
        self.stride = stride
        self.group_size = in_channels * np.prod(stride) // out_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // np.prod(stride),
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, x, causal: bool = True):
        if self.stride[0] == 2:
            x = torch.cat(
                [x[:, :, :1, :, :], x], dim=2
            )  # duplicate the first frame as temporal padding

        # skip connection
        x_in = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
        x_in = x_in.mean(dim=2)

        # conv
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

        x = x + x_in

        return x


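class DepthToSpaceUpsample(nn.Module):
    """
    Upsamples by running a convolution and unfolding channels into `stride`-sized
    blocks of positions with `PixelShuffleND` (depth-to-space).

    When the temporal stride is 2, the first output frame is dropped. If `residual`
    is set, the pixel-shuffled input, repeated along the channel dimension to match
    the output, is added to the result.
    """
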
    def __init__(
        self,
        dims,
        in_channels,
        stride,
        residual=False,
        out_channels_reduction_factor=1,
        spatial_padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        self.out_channels = (
            np.prod(stride) * in_channels // out_channels_reduction_factor
        )
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(self, x, causal: bool = True):
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = self.pixel_shuffle(x)
            num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
            if self.stride[0] == 2:
                x_in = x_in[:, :, 1:, :, :]
        x = self.conv(x, causal=causal)
        x = self.pixel_shuffle(x)
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x
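

# --- Illustrative sketch (not part of the upstream module) -------------------
# A rough shape walk-through for DepthToSpaceUpsample.forward, assuming a 5D
# input of shape (b, c, d, h, w). The helper name is hypothetical; the
# arithmetic only mirrors what the conv, the pixel shuffle and the first-frame
# drop above appear to do, and it does not touch any tensors.
def _sketch_depth_to_space_output_shape(
    shape, stride, out_channels_reduction_factor=1
):
    b, c, d, h, w = shape
    p1, p2, p3 = stride
    # The conv widens the channels to prod(stride) * c // reduction_factor ...
    conv_channels = p1 * p2 * p3 * c // out_channels_reduction_factor
    # ... and the pixel shuffle moves prod(stride) of them into d, h and w.
    out_channels = conv_channels // (p1 * p2 * p3)
    out_d = d * p1
    if p1 == 2:
        # The first upsampled frame is dropped when time is doubled.
        out_d -= 1
    return (b, out_channels, out_d, h * p2, w * p3)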


class LayerNorm(nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        x = rearrange(x, "b c d h w -> b d h w c")
        x = self.norm(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x


class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If `None`, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the group normalization layers.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        if norm_layer == "group_norm":
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        if norm_layer == "group_norm":
            self.norm2 = nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # similar to the "explicit noise inputs" method in style-gan
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            ada_values = self.scale_shift_table[
                None, ..., None, None, None
            ] + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale1
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale2
            )

        input_tensor = self.norm3(input_tensor)

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
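

# --- Illustrative sketch (not part of the upstream module) -------------------
# The timestep conditioning in ResnetBlock3D.forward amounts to a per-channel
# scale-and-shift: a learned table row is added to a timestep embedding, and
# the sum modulates the normalized activations as h * (1 + scale) + shift.
# The helper below replays that arithmetic on plain Python lists for a single
# spatial position. Its name and list-based inputs are assumptions made for
# illustration only.
def _sketch_scale_shift(hidden, table_shift, table_scale, t_shift, t_scale):
    # Combine the learned table entries with the timestep embedding.
    shift = [a + b for a, b in zip(table_shift, t_shift)]
    scale = [a + b for a, b in zip(table_scale, t_scale)]
    # Modulate each channel value.
    return [h * (1 + s) + b for h, s, b in zip(hidden, scale, shift)]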


def patchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x
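

# --- Illustrative sketch (not part of the upstream module) -------------------
# For a 5D input, patchify folds patch_size_t frames and a
# patch_size_hw x patch_size_hw pixel block into the channel axis, and
# unpatchify reverses that. The hypothetical helper below only computes the
# resulting shape, assuming every patched axis is divisible by its patch size.
def _sketch_patchify_shape(shape, patch_size_hw, patch_size_t=1):
    b, c, f, h, w = shape
    if f % patch_size_t or h % patch_size_hw or w % patch_size_hw:
        raise ValueError("each patched axis must be divisible by its patch size")
    return (
        b,
        c * patch_size_t * patch_size_hw**2,
        f // patch_size_t,
        h // patch_size_hw,
        w // patch_size_hw,
    )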


def create_video_autoencoder_demo_config(
    latent_channels: int = 64,
):
    encoder_blocks = [
        ("res_x", {"num_layers": 2}),
        ("compress_space_res", {"multiplier": 2}),
        ("compress_time_res", {"multiplier": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 1}),
    ]
    decoder_blocks = [
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 2, "inject_noise": False}),
    ]
    return {
        "_class_name": "CausalVideoAutoencoder",
        "dims": 3,
        "encoder_blocks": encoder_blocks,
        "decoder_blocks": decoder_blocks,
        "latent_channels": latent_channels,
        "norm_layer": "pixel_norm",
        "patch_size": 4,
        "latent_log_var": "uniform",
        "use_quant_conv": False,
        "causal_decoder": False,
        "timestep_conditioning": True,
        "spatial_padding_mode": "replicate",
    }


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_demo_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)

    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 17, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")

    timestep = torch.ones(input_videos.shape[0]) * 0.1
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape, timestep=timestep
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Validate that single image gets treated the same way as first frame
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    _ = video_autoencoder.decode(
        image_latent, target_shape=image_latent.shape, timestep=timestep
    ).sample

    first_frame_latent = latent[:, :, :1, :, :]

    assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
    # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Ensure to call the demo function to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()


================================================
FILE: ltx_video/models/autoencoders/conv_nd_factory.py
================================================
from typing import Tuple, Union

import torch

from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
    spatial_padding_mode="zeros",
    temporal_padding_mode="zeros",
):
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    elif dims == 3 or dims == (2, 1):
        return torch.nn.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


================================================
FILE: ltx_video/models/autoencoders/dual_conv3d.py
================================================
import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DualConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(DualConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if kernel_size == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)

        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias

        # Define the size of the channels after the first convolution
        intermediate_channels = (
            out_channels if in_channels < out_channels else in_channels
        )

        # Define parameters for the first convolution
        self.weight1 = nn.Parameter(
            torch.Tensor(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
        else:
            self.register_parameter("bias1", None)

        # Define parameters for the second convolution
        self.weight2 = nn.Parameter(
            torch.Tensor(
                out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
            )
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias2", None)

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        else:
            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x, skip_time_conv):
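        if self.padding_mode != "zeros":
            # F.conv3d has no `padding_mode` argument (only nn.Conv3d does), so
            # other modes would need an explicit F.pad call before each conv.
            raise NotImplementedError(
                f"padding_mode={self.padding_mode!r} is not supported by DualConv3d"
            )

        # First convolution (spatial only: the kernel depth is 1)
        x = F.conv3d(
            x,
            self.weight1,
            self.bias1,
            self.stride1,
            self.padding1,
            self.dilation1,
            self.groups,
        )

        if skip_time_conv:
            return x

        # Second convolution (temporal only: the kernel height and width are 1)
        x = F.conv3d(
            x,
            self.weight2,
            self.bias2,
            self.stride2,
            self.padding2,
            self.dilation2,
            self.groups,
        )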

        return x

    def forward_with_2d(self, x, skip_time_conv):
        b, c, d, h, w = x.shape

        if self.padding_mode != "zeros":
            # F.conv2d / F.conv1d have no `padding_mode` argument (only the nn.Conv*
            # modules do), so other modes would need an explicit F.pad call first.
            raise NotImplementedError(
                f"padding_mode={self.padding_mode!r} is not supported by DualConv3d"
            )

        # First 2D convolution
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # Squeeze the depth dimension out of weight1 since it's 1
        weight1 = self.weight1.squeeze(2)
        # Select stride, padding, and dilation for the 2D convolution
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
        x = F.conv2d(
            x,
            weight1,
            self.bias1,
            stride1,
            padding1,
            dilation1,
            self.groups,
        )

        _, _, h, w = x.shape

        if skip_time_conv:
            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
            return x

        # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)

        # Reshape weight2 to match the expected dimensions for conv1d
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
        stride2 = self.stride2[0]
        padding2 = self.padding2[0]
        dilation2 = self.dilation2[0]
        x = F.conv1d(
            x,
            weight2,
            self.bias2,
            stride2,
            padding2,
            dilation2,
            self.groups,
        )
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

        return x

    @property
    def weight(self):
        return self.weight2
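

# --- Illustrative sketch (not part of the upstream module) -------------------
# DualConv3d factorizes a 3D convolution into a spatial (1 x k_h x k_w) conv
# followed by a temporal (k_t x 1 x 1) conv. The hypothetical helper below
# counts the parameters implied by the weight and bias shapes defined in
# __init__ and compares them with a plain, unfactorized 3D conv of the same
# kernel size. It is a back-of-the-envelope aid, not a measurement.
def _sketch_dual_conv3d_param_count(
    in_channels, out_channels, kernel_size, groups=1, bias=True
):
    k_t, k_h, k_w = kernel_size
    # Same rule as intermediate_channels in DualConv3d.__init__.
    mid = out_channels if in_channels < out_channels else in_channels
    spatial = mid * (in_channels // groups) * k_h * k_w  # weight1
    temporal = out_channels * (mid // groups) * k_t  # weight2
    dual = spatial + temporal + ((mid + out_channels) if bias else 0)
    full = out_channels * (in_channels // groups) * k_t * k_h * k_w
    full += out_channels if bias else 0
    return dual, full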


def test_dual_conv3d_consistency():
    # Initialize parameters
    in_channels = 3
    out_channels = 5
    kernel_size = (3, 3, 3)
    stride = (2, 2, 2)
    padding = (1, 1, 1)

    # Create an instance of the DualConv3d class
    dual_conv3d = DualConv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )

    # Example input tensor
    test_input = torch.randn(1, 3, 10, 10, 10)

    # Perform forward passes with both 3D and 2D settings
    output_conv3d = dual_conv3d(test_input, use_conv3d=True)
    output_2d = dual_conv3d(test_input, use_conv3d=False)

    # Assert that the outputs from both methods are sufficiently close
    assert torch.allclose(
        output_conv3d, output_2d, atol=1e-6
    ), "Outputs are not consistent between 3D and 2D convolutions."


================================================
FILE: ltx_video/models/autoencoders/latent_upsampler.py
================================================
from typing import Optional, Union
from pathlib import Path
import os
import json

import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open

from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND


class ResBlock(nn.Module):
    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = channels

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activation(x + residual)
        return x


class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to spatially upsample VAE latents.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()

        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        if spatial_upsample and temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )

        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = latent.shape

        if self.dims == 2:
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                x = x[:, :, 1:, :, :]
            else:
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x

    @classmethod
    def from_config(cls, config):
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    def config(self):
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
            ".safetensors"
        ):
            state_dict = {}
            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            config = json.loads(metadata["config"])
            with torch.device("meta"):
                latent_upsampler = LatentUpsampler.from_config(config)
            latent_upsampler.load_state_dict(state_dict, assign=True)
        else:
            raise ValueError(
                f"Expected a path to a .safetensors file, got: {pretrained_model_path}"
            )
        return latent_upsampler


if __name__ == "__main__":
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(p.numel() for p in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    latent = torch.randn(1, 128, 9, 16, 16)
    upsampled_latent = latent_upsampler(latent)
    print(f"Upsampled latent shape: {upsampled_latent.shape}")
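

# --- Illustrative sketch (not part of the upstream module) -------------------
# Expected output shape of LatentUpsampler.forward, assuming dims=3 and at
# least one of the two upsampling flags set. The helper name is hypothetical;
# it only replays the shape arithmetic of the forward pass above (pixel
# shuffle by 2, plus dropping the first frame after temporal upsampling).
def _sketch_latent_upsampler_output_shape(
    shape, spatial_upsample=True, temporal_upsample=False
):
    b, c, f, h, w = shape
    if temporal_upsample:
        # Frames are doubled by the pixel shuffle, then the first one is dropped.
        f = 2 * f - 1
    if spatial_upsample:
        h, w = 2 * h, 2 * w
    return (b, c, f, h, w)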


================================================
FILE: ltx_video/models/autoencoders/pixel_norm.py
================================================
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
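

# --- Illustrative sketch (not part of the upstream module) -------------------
# PixelNorm divides each position's channel vector by its root mean square, so
# the mean of the squared outputs is close to 1. The hypothetical helper below
# shows the same formula on a plain list holding the channel values of one
# pixel; it is a sketch for intuition, not a replacement for the module.
import math


def _sketch_pixel_norm(channel_values, eps=1e-8):
    mean_sq = sum(v * v for v in channel_values) / len(channel_values)
    return [v / math.sqrt(mean_sq + eps) for v in channel_values]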


================================================
FILE: ltx_video/models/autoencoders/pixel_shuffle.py
================================================
import torch.nn as nn
from einops import rearrange


class PixelShuffleND(nn.Module):
    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
                p3=self.upscale_factors[2],
            )
        elif self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
            )
        elif self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=self.upscale_factors[0],
            )
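

# --- Illustrative sketch (not part of the upstream module) -------------------
# The dims == 1 pattern "b (c p1) f h w -> b c (f p1) h w" interleaves groups of
# p1 consecutive channels along the frame axis. The hypothetical helper below
# reproduces that index mapping on nested Python lists for one batch item and
# one spatial position, where channels[k][f] is the value of channel k at
# frame f.
def _sketch_temporal_pixel_shuffle(channels, upscale=2):
    num_out_channels = len(channels) // upscale
    num_frames = len(channels[0])
    out = []
    for c in range(num_out_channels):
        frames = []
        for f in range(num_frames):
            for i in range(upscale):
                # Output frame f * upscale + i comes from input channel c * upscale + i.
                frames.append(channels[c * upscale + i][f])
        out.append(frames)
    return out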


================================================
FILE: ltx_video/models/autoencoders/vae.py
================================================
from typing import Optional, Union

import torch
import inspect
import math
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.models.autoencoders.vae import (
    DecoderOutput,
    DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd


class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    """Variational Autoencoder (VAE) model with KL loss.

    VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
    This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.

    Args:
        encoder (`nn.Module`):
            Encoder module.
        decoder (`nn.Module`):
            Decoder module.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_channels: int = 4,
        dims: int = 2,
        sample_size=512,
        use_quant_conv: bool = True,
        normalize_latent_channels: bool = False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
        self.normalize_latent_channels = normalize_latent_channels

        # pass init params to Decoder
        quant_dims = 2 if dims == 2 else 3
        self.decoder = decoder
        if use_quant_conv:
            self.quant_conv = make_conv_nd(
                quant_dims, 2 * latent_channels, 2 * latent_channels, 1
            )
            self.post_quant_conv = make_conv_nd(
                quant_dims, latent_channels, latent_channels, 1
            )
        else:
            self.quant_conv = nn.Identity()
            self.post_quant_conv = nn.Identity()

        if normalize_latent_channels:
            if dims == 2:
                self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
            else:
                self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
        else:
            self.latent_norm_out = nn.Identity()
        self.use_z_tiling = False
        self.use_hw_tiling = False
        self.dims = dims
        self.z_sample_size = 1

        self.decoder_params = inspect.signature(self.decoder.forward).parameters

        # only relevant if vae tiling is enabled
        self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)

    def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
        self.tile_sample_min_size = sample_size
        num_blocks = len(self.encoder.down_blocks)
        self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
        self.tile_overlap_factor = overlap_factor

    def enable_z_tiling(self, z_sample_size: int = 8):
        r"""
        Enable tiling along the frame (temporal) dimension during VAE encoding and decoding.

        When this option is enabled, the VAE splits the input tensor along the frame axis into chunks derived from
        `z_sample_size` and processes them in several steps. This saves memory and allows longer inputs.
        """
        # Validate before mutating state so a bad value leaves the model unchanged.
        assert (
            z_sample_size % 8 == 0 or z_sample_size == 1
        ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
        self.use_z_tiling = z_sample_size > 1
        self.z_sample_size = z_sample_size

    def disable_z_tiling(self):
        r"""
        Disable tiling along the frame dimension. If `enable_z_tiling` was previously invoked, this method goes back
        to processing the input in one step.
        """
        self.use_z_tiling = False

    def enable_hw_tiling(self):
        r"""
        Enable tiling during VAE encoding and decoding along the height and width dimensions.
        """
        self.use_hw_tiling = True

    def disable_hw_tiling(self):
        r"""
        Disable tiling during VAE encoding and decoding along the height and width dimensions.
        """
        self.use_hw_tiling = False

    def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tiles of side `tile_sample_min_size`
        # (512 by default) and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        return moments

    def blend_z(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for z in range(blend_extent):
            b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
                1 - z / blend_extent
            ) + b[:, :, z, :, :] * (z / blend_extent)
        return b

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        tile_target_shape = (
            *target_shape[:3],
            self.tile_sample_min_size,
            self.tile_sample_min_size,
        )
        # Split z into overlapping tiles of side `tile_latent_min_size` and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile, target_shape=tile_target_shape)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def encode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, tuple]:
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            num_splits = z.shape[2] // self.z_sample_size
            sizes = [self.z_sample_size] * num_splits
            sizes = (
                sizes + [z.shape[2] - sum(sizes)]
                if z.shape[2] - sum(sizes) > 0
                else sizes
            )
            tiles = z.split(sizes, dim=2)
            moments_tiles = [
                (
                    self._hw_tiled_encode(z_tile, return_dict)
                    if self.use_hw_tiling
                    else self._encode(z_tile)
                )
                for z_tile in tiles
            ]
            moments = torch.cat(moments_tiles, dim=2)

        else:
            moments = (
                self._hw_tiled_encode(z, return_dict)
                if self.use_hw_tiling
                else self._encode(z)
            )

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            _, c, _, _, _ = z.shape
            z = torch.cat(
                [
                    self.latent_norm_out(z[:, : c // 2, :, :, :]),
                    z[:, c // 2 :, :, :, :],
                ],
                dim=1,
            )
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
            running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
            eps = self.latent_norm_out.eps

            z = z * torch.sqrt(running_var + eps) + running_mean
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        moments = self._normalize_latent_channels(moments)
        return moments

    def _decode(
        self,
        z: torch.FloatTensor,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self._unnormalize_latent_channels(z)
        z = self.post_quant_conv(z)
        if "timestep" in self.decoder_params:
            dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
        else:
            dec = self.decoder(z, target_shape=target_shape)
        return dec

    def decode(
        self,
        z: torch.FloatTensor,
        return_dict: bool = True,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        assert target_shape is not None, "target_shape must be provided for decoding"
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            reduction_factor = int(
                self.encoder.patch_size_t
                * 2
                ** (
                    len(self.encoder.down_blocks)
                    - 1
                    - math.sqrt(self.encoder.patch_size)
                )
            )
            split_size = self.z_sample_size // reduction_factor
            num_splits = z.shape[2] // split_size

            # copy target shape, and divide frame dimension (=2) by the context size
            target_shape_split = list(target_shape)
            target_shape_split[2] = target_shape[2] // num_splits

            decoded_tiles = [
                (
                    self._hw_tiled_decode(z_tile, target_shape_split)
                    if self.use_hw_tiling
                    else self._decode(z_tile, target_shape=target_shape_split)
                )
                for z_tile in torch.tensor_split(z, num_splits, dim=2)
            ]
            decoded = torch.cat(decoded_tiles, dim=2)
        else:
            decoded = (
                self._hw_tiled_decode(z, target_shape)
                if self.use_hw_tiling
                else self._decode(z, target_shape=target_shape, timestep=timestep)
            )

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used to sample from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, target_shape=sample.shape).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)


================================================
FILE: ltx_video/models/autoencoders/vae_encode.py
================================================
from typing import Tuple
import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor


from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.video_autoencoder import (
    Downsample3D,
    VideoAutoencoder,
)

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
            pre-configured and loaded with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, normalize the latents with the VAE's
            per-channel statistics (`mean_of_means`, `std_of_means`) instead of multiplying by
            `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
            to match the input shape, scaled by the model's configuration.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(videos, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        When `vae` is not a `VideoAutoencoder` or `CausalVideoAutoencoder`, a video is encoded
        frame by frame.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(media_items) // split_size
        # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
        latents = []
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents


def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    """
    Decodes latents back into pixel space using the given VAE.

    Mirrors `vae_encode`: when `vae` is not a `VideoAutoencoder` or `CausalVideoAutoencoder`,
    5D latents are decoded frame by frame, and `split_size` > 1 decodes the batch in that many
    sub-batches. `timestep` is forwarded to the video VAE's `decode` only when it is not None.
    """
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(
                latent_batch, vae, is_video, vae_per_channel_normalize, timestep
            )
            for latent_batch in latents.split(encode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images


def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        vae_decode_kwargs = {}
        if timestep is not None:
            vae_decode_kwargs["timestep"] = timestep
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(
                1,
                3,
                fl * temporal_scale if is_video else 1,
                hl * spatial_scale,
                wl * spatial_scale,
            ),
            **vae_decode_kwargs,
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]
    return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
    # Returns (temporal, spatial, spatial) downscale factors.
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len(
            [
                block
                for block in vae.encoder.down_blocks
                if isinstance(block.downsample, Downsample3D)
            ]
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)


def latent_to_pixel_coords(
    latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
) -> Tensor:
    """
    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
    configuration.

    Args:
        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
        containing the latent corner coordinates of each token.
        vae (AutoencoderKL): The VAE model
        causal_fix (bool): Whether to take into account the different temporal scale
            of the first frame. Default = False for backwards compatibility.
    Returns:
        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
    """

    scale_factors = get_vae_size_scale_factor(vae)
    causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
    pixel_coords = latent_to_pixel_coords_from_factors(
        latent_coords, scale_factors, causal_fix
    )
    return pixel_coords


def latent_to_pixel_coords_from_factors(
    latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
) -> Tensor:
    pixel_coords = (
        latent_coords
        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
    )
    if causal_fix:
        # Fix temporal scale for first frame to 1 due to causality
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords


def normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    return (
        (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
        / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents * vae.config.scaling_factor
    )


def un_normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    return (
        latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents / vae.config.scaling_factor
    )


================================================
FILE: ltx_video/models/autoencoders/video_autoencoder.py
================================================
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional

from diffusers.utils import logging

from ltx_video.utils.torch_utils import Identity
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper

logger = logging.get_logger(__name__)


class VideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        config_local_path = pretrained_model_name_or_path / "config.json"
        config = cls.load_config(config_local_path, **kwargs)
        video_vae = cls.from_config(config)
        video_vae.to(kwargs["torch_dtype"])

        model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
        ckpt_state_dict = torch.load(model_local_path)
        video_vae.load_state_dict(ckpt_state_dict)

        statistics_local_path = (
            pretrained_model_name_or_path / "per_channel_statistics.json"
        )
        if statistics_local_path.exists():
            with open(statistics_local_path, "r") as file:
                data = json.load(file)
            transposed_data = list(zip(*data["data"]))
            data_dict = {
                col: torch.tensor(vals)
                for col, vals in zip(data["columns"], transposed_data)
            }
            video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
            video_vae.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "VideoAutoencoder"
        ), "config must have _class_name=VideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)

        if use_quant_conv and latent_log_var == "uniform":
            raise ValueError("uniform latent_log_var requires use_quant_conv=False")

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        dims = config["dims"]
        return VideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="VideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels
            // (self.encoder.patch_size_t * self.encoder.patch_size**2),
            out_channels=self.decoder.conv_out.out_channels
            // (self.decoder.patch_size_t * self.decoder.patch_size**2),
            latent_channels=self.decoder.conv_in.in_channels,
            block_out_channels=[
                self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
                for i in range(len(self.encoder.down_blocks))
            ],
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            patch_size_t=self.encoder.patch_size_t,
            add_channel_padding=self.encoder.add_channel_padding,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def downscale_factor(self):
        # `Encoder` exposes this value as the `downscale_factor` property.
        return self.encoder.downscale_factor

    def to_json_string(self) -> str:
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        model_keys = set(name for name, _ in self.named_parameters())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }

        converted_state_dict = {}
        for key, value in state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            if "norm" in key and key not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            How the log variance is represented in the output. `per_channel` doubles the output channels,
            `uniform` adds a single extra channel, and `none` adds no log-variance channels.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        if add_channel_padding:
            in_channels = in_channels * self.patch_size**3
        else:
            in_channels = in_channels * self.patch_size_t * self.patch_size**2
        self.in_channels = in_channels
        output_channel = block_out_channels[0]

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = DownEncoderBlock3D(
                dims=dims,
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                add_downsample=not is_final_block and 2**i >= patch_size,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[-1],
                num_groups=norm_num_groups,
                eps=1e-6,
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims, block_out_channels[-1], conv_out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    @property
    def downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.down_blocks
                    if isinstance(block.downsample, Downsample3D)
                ]
            )
            * self.patch_size
        )

    def forward(
        self, sample: torch.FloatTensor, return_features=False
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        downsample_in_time = sample.shape[2] != 1

        # patchify
        patch_size_t = self.patch_size_t if downsample_in_time else 1
        sample = patchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        sample = self.conv_in(sample)

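        # checkpoint() must receive the module together with its inputs.
        checkpoint_fn = (
            (
                lambda module: partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else lambda module: module
        )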

        if return_features:
            features = []
        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(
                sample, downsample_in_time=downsample_in_time
            )
            if return_features:
                features.append(sample)

        sample = checkpoint_fn(self.mid_block)(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

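        if self.latent_log_var == "uniform":
            # conv_out produced latent_channels means plus one shared log-variance
            # channel; tile that channel so the output has 2 * latent_channels.
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()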

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        if return_features:
            features.append(sample[:, : self.latent_channels, ...])
            return sample, features
        return sample
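

# Illustrative sketch (not part of the original module): how the Encoder's
# `latent_log_var` mode appears to determine channel counts, based on the
# constructor and `forward` above. The helper name
# `_sketch_encoder_channel_counts` is hypothetical, and it works on plain
# integers rather than tensors, so it only mirrors the arithmetic.
def _sketch_encoder_channel_counts(latent_channels, latent_log_var="per_channel"):
    """Return (conv_out_channels, channels_returned_by_forward)."""
    if latent_log_var == "per_channel":
        # One mean and one log-variance channel per latent channel.
        conv_out_channels = latent_channels * 2
        returned_channels = conv_out_channels
    elif latent_log_var == "uniform":
        # conv_out appends a single shared log-variance channel.
        conv_out_channels = latent_channels + 1
        # forward() then concatenates (conv_out_channels - 2) more copies of it.
        returned_channels = conv_out_channels + (conv_out_channels - 2)
    elif latent_log_var == "none":
        conv_out_channels = latent_channels
        returned_channels = conv_out_channels
    else:
        raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
    return conv_out_channels, returned_channels


# Example under the assumptions above: 128 latent channels in "uniform" mode
# give 129 conv_out channels, expanded to 256 (128 means + 128 tiled log-vars).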


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        if add_channel_padding:
            out_channels = out_channels * self.patch_size**3
        else:
            out_channels = out_channels * self.patch_size_t * self.patch_size**2
        self.out_channels = out_channels

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = UpDecoderBlock3D(
                dims=dims,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block
                and 2 ** (len(block_out_channels) - i - 1) > patch_size,
                resnet_eps=1e-6,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.up_blocks.append(up_block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims, block_out_channels[0], out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        assert target_shape is not None, "target_shape must be provided"
        upsample_in_time = sample.shape[2] < target_shape[2]

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

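        # checkpoint() must receive the module together with its inputs.
        checkpoint_fn = (
            (
                lambda module: partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else lambda module: module
        )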

        sample = checkpoint_fn(self.mid_block)(sample)
        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # un-patchify
        patch_size_t = self.patch_size_t if upsample_in_time else 1
        sample = unpatchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        return sample
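

# Illustrative sketch (not part of the original module): which blocks receive a
# resampling layer under the `add_downsample` / `add_upsample` conditions used
# in the Encoder and Decoder constructors above. The helper names are
# hypothetical, and the factor computed here assumes every flagged block halves
# the resolution, as the `downscale_factor` property does.
def _sketch_downsample_flags(num_blocks, patch_size):
    """Encoder rule: every non-final block i with 2**i >= patch_size."""
    return [i != num_blocks - 1 and 2**i >= patch_size for i in range(num_blocks)]


def _sketch_upsample_flags(num_blocks, patch_size):
    """Decoder rule: non-final block i with 2**(num_blocks - i - 1) > patch_size."""
    return [
        i != num_blocks - 1 and 2 ** (num_blocks - i - 1) > patch_size
        for i in range(num_blocks)
    ]


def _sketch_downscale_factor(num_blocks, patch_size):
    """Mirror of Encoder.downscale_factor: 2**(number of downsamples) * patch_size."""
    return 2 ** sum(_sketch_downsample_flags(num_blocks, patch_size)) * patch_size


# Example under the assumptions above, for four blocks:
#   patch_size=1 -> downsample flags [True, True, True, False], factor 8
#   patch_size=4 -> downsample flags [False, False, True, False], factor 8
# Patchifying stands in for the early downsampling stages, so the overall
# factor is unchanged; the decoder rule yields the same number of upsamples.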


class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_downsample:
            self.downsample = Downsample3D(
                dims,
                out_channels,
                out_channels=out_channels,
                padding=downsample_padding,
            )
        else:
            self.downsample = Identity()

    def forward(
        self, hidden_states: torch.FloatTensor, downsample_in_time
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.downsamp
SYMBOL INDEX (287 symbols across 28 files)

FILE: inference.py
  function main (line 6) | def main():

FILE: ltx_video/inference.py
  function get_total_gpu_memory (line 44) | def get_total_gpu_memory():
  function get_device (line 51) | def get_device():
  function load_image_to_tensor_with_resize_and_crop (line 59) | def load_image_to_tensor_with_resize_and_crop(
  function calculate_padding (line 108) | def calculate_padding(
  function convert_prompt_to_filename (line 128) | def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
  function get_unique_filename (line 155) | def get_unique_filename(
  function seed_everething (line 175) | def seed_everething(seed: int):
  function create_transformer (line 185) | def create_transformer(ckpt_path: str, precision: str) -> Transformer3DM...
  function create_ltx_video_pipeline (line 207) | def create_ltx_video_pipeline(
  function create_latent_upsampler (line 294) | def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
  function load_pipeline_config (line 301) | def load_pipeline_config(pipeline_config: str):
  class InferenceConfig (line 317) | class InferenceConfig:
  function infer (line 389) | def infer(config: InferenceConfig):
  function prepare_conditioning (line 637) | def prepare_conditioning(
  function get_media_num_frames (line 690) | def get_media_num_frames(media_path: str) -> int:
  function load_media_file (line 702) | def load_media_file(

FILE: ltx_video/models/autoencoders/causal_conv3d.py
  class CausalConv3d (line 7) | class CausalConv3d(nn.Module):
    method __init__ (line 8) | def __init__(
    method forward (line 44) | def forward(self, x, causal: bool = True):
    method weight (line 62) | def weight(self):

FILE: ltx_video/models/autoencoders/causal_video_autoencoder.py
  class CausalVideoAutoencoder (line 33) | class CausalVideoAutoencoder(AutoencoderKLWrapper):
    method from_pretrained (line 35) | def from_pretrained(
    method from_config (line 123) | def from_config(config):
    method config (line 180) | def config(self):
    method is_video_supported (line 201) | def is_video_supported(self):
    method spatial_downscale_factor (line 208) | def spatial_downscale_factor(self):
    method temporal_downscale_factor (line 228) | def temporal_downscale_factor(self):
    method to_json_string (line 243) | def to_json_string(self) -> str:
    method load_state_dict (line 248) | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool ...
    method last_layer (line 298) | def last_layer(self):
    method set_use_tpu_flash_attention (line 308) | def set_use_tpu_flash_attention(self):
  class Encoder (line 315) | class Encoder(nn.Module):
    method __init__ (line 340) | def __init__(
    method forward (line 508) | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
  class Decoder (line 558) | class Decoder(nn.Module):
    method __init__ (line 583) | def __init__(
    method forward (line 733) | def forward(
  class UNetMidBlock3D (line 803) | class UNetMidBlock3D(nn.Module):
    method __init__ (line 829) | def __init__(
    method forward (line 895) | def forward(
  class SpaceToDepthDownsample (line 974) | class SpaceToDepthDownsample(nn.Module):
    method __init__ (line 975) | def __init__(self, dims, in_channels, out_channels, stride, spatial_pa...
    method forward (line 989) | def forward(self, x, causal: bool = True):
  class DepthToSpaceUpsample (line 1021) | class DepthToSpaceUpsample(nn.Module):
    method __init__ (line 1022) | def __init__(
    method forward (line 1049) | def forward(self, x, causal: bool = True):
  class LayerNorm (line 1066) | class LayerNorm(nn.Module):
    method __init__ (line 1067) | def __init__(self, dim, eps, elementwise_affine=True) -> None:
    method forward (line 1071) | def forward(self, x):
  class ResnetBlock3D (line 1078) | class ResnetBlock3D(nn.Module):
    method __init__ (line 1091) | def __init__(
    method _feed_spatial_noise (line 1181) | def _feed_spatial_noise(
    method forward (line 1195) | def forward(
  function patchify (line 1259) | def patchify(x, patch_size_hw, patch_size_t=1):
  function unpatchify (line 1280) | def unpatchify(x, patch_size_hw, patch_size_t=1):
  function create_video_autoencoder_demo_config (line 1300) | def create_video_autoencoder_demo_config(
  function test_vae_patchify_unpatchify (line 1334) | def test_vae_patchify_unpatchify():
  function demo_video_autoencoder_forward_backward (line 1343) | def demo_video_autoencoder_forward_backward():

FILE: ltx_video/models/autoencoders/conv_nd_factory.py
  function make_conv_nd (line 9) | def make_conv_nd(
  function make_linear_nd (line 75) | def make_linear_nd(

FILE: ltx_video/models/autoencoders/dual_conv3d.py
  class DualConv3d (line 10) | class DualConv3d(nn.Module):
    method __init__ (line 11) | def __init__(
    method reset_parameters (line 86) | def reset_parameters(self):
    method forward (line 97) | def forward(self, x, use_conv3d=False, skip_time_conv=False):
    method forward_with_3d (line 103) | def forward_with_3d(self, x, skip_time_conv):
    method forward_with_2d (line 133) | def forward_with_2d(self, x, skip_time_conv):
    method weight (line 185) | def weight(self):
  function test_dual_conv3d_consistency (line 189) | def test_dual_conv3d_consistency():

FILE: ltx_video/models/autoencoders/latent_upsampler.py
  class ResBlock (line 15) | class ResBlock(nn.Module):
    method __init__ (line 16) | def __init__(
    method forward (line 31) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class LatentUpsampler (line 42) | class LatentUpsampler(ModelMixin, ConfigMixin):
    method __init__ (line 55) | def __init__(
    method forward (line 109) | def forward(self, latent: torch.Tensor) -> torch.Tensor:
    method from_config (line 152) | def from_config(cls, config):
    method config (line 162) | def config(self):
    method from_pretrained (line 174) | def from_pretrained(

FILE: ltx_video/models/autoencoders/pixel_norm.py
  class PixelNorm (line 5) | class PixelNorm(nn.Module):
    method __init__ (line 6) | def __init__(self, dim=1, eps=1e-8):
    method forward (line 11) | def forward(self, x):

FILE: ltx_video/models/autoencoders/pixel_shuffle.py
  class PixelShuffleND (line 5) | class PixelShuffleND(nn.Module):
    method __init__ (line 6) | def __init__(self, dims, upscale_factors=(2, 2, 2)):
    method forward (line 12) | def forward(self, x):

FILE: ltx_video/models/autoencoders/vae.py
  class AutoencoderKLWrapper (line 16) | class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    method __init__ (line 31) | def __init__(
    method set_tiling_params (line 79) | def set_tiling_params(self, sample_size: int = 512, overlap_factor: fl...
    method enable_z_tiling (line 85) | def enable_z_tiling(self, z_sample_size: int = 8):
    method disable_z_tiling (line 98) | def disable_z_tiling(self):
    method enable_hw_tiling (line 105) | def enable_hw_tiling(self):
    method disable_hw_tiling (line 111) | def disable_hw_tiling(self):
    method _hw_tiled_encode (line 117) | def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = T...
    method blend_z (line 154) | def blend_z(
    method blend_v (line 164) | def blend_v(
    method blend_h (line 174) | def blend_h(
    method _hw_tiled_decode (line 184) | def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
    method encode (line 226) | def encode(
    method _normalize_latent_channels (line 261) | def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.Fl...
    method _unnormalize_latent_channels (line 275) | def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch....
    method _encode (line 286) | def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
    method _decode (line 292) | def _decode(
    method decode (line 306) | def decode(
    method forward (line 352) | def forward(

FILE: ltx_video/models/autoencoders/vae_encode.py
  function vae_encode (line 22) | def vae_encode(
  function vae_decode (line 94) | def vae_decode(
  function _run_decoder (line 134) | def _run_decoder(
  function get_vae_size_scale_factor (line 168) | def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
  function latent_to_pixel_coords (line 190) | def latent_to_pixel_coords(
  function latent_to_pixel_coords_from_factors (line 215) | def latent_to_pixel_coords_from_factors(
  function normalize_latents (line 228) | def normalize_latents(
  function un_normalize_latents (line 239) | def un_normalize_latents(

FILE: ltx_video/models/autoencoders/video_autoencoder.py
  class VideoAutoencoder (line 22) | class VideoAutoencoder(AutoencoderKLWrapper):
    method from_pretrained (line 24) | def from_pretrained(
    method from_config (line 61) | def from_config(config):
    method config (line 112) | def config(self):
    method is_video_supported (line 135) | def is_video_supported(self):
    method downscale_factor (line 142) | def downscale_factor(self):
    method to_json_string (line 145) | def to_json_string(self) -> str:
    method load_state_dict (line 150) | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool ...
    method last_layer (line 174) | def last_layer(self):
  class Encoder (line 185) | class Encoder(nn.Module):
    method __init__ (line 208) | def __init__(
    method downscale_factor (line 300) | def downscale_factor(self):
    method forward (line 313) | def forward(
  class Decoder (line 378) | class Decoder(nn.Module):
    method __init__ (line 399) | def __init__(
    method forward (line 479) | def forward(self, sample: torch.FloatTensor, target_shape) -> torch.Fl...
  class DownEncoderBlock3D (line 517) | class DownEncoderBlock3D(nn.Module):
    method __init__ (line 518) | def __init__(
    method forward (line 560) | def forward(
  class UNetMidBlock3D (line 573) | class UNetMidBlock3D(nn.Module):
    method __init__ (line 591) | def __init__(
    method forward (line 621) | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  class UpDecoderBlock3D (line 628) | class UpDecoderBlock3D(nn.Module):
    method __init__ (line 629) | def __init__(
    method forward (line 671) | def forward(
  class ResnetBlock3D (line 682) | class ResnetBlock3D(nn.Module):
    method __init__ (line 695) | def __init__(
    method forward (line 746) | def forward(
  class Downsample3D (line 773) | class Downsample3D(nn.Module):
    method __init__ (line 774) | def __init__(
    method forward (line 796) | def forward(self, x, downsample_in_time=True):
  class Upsample3D (line 812) | class Upsample3D(nn.Module):
    method __init__ (line 819) | def __init__(self, dims, channels, out_channels=None):
    method forward (line 828) | def forward(self, x, upsample_in_time):
  function patchify (line 868) | def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
  function unpatchify (line 906) | def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=Fal...
  function create_video_autoencoder_config (line 934) | def create_video_autoencoder_config(
  function create_video_autoencoder_pathify4x4x4_config (line 958) | def create_video_autoencoder_pathify4x4x4_config(
  function create_video_autoencoder_pathify4x4_config (line 979) | def create_video_autoencoder_pathify4x4_config(
  function test_vae_patchify_unpatchify (line 997) | def test_vae_patchify_unpatchify():
  function demo_video_autoencoder_forward_backward (line 1006) | def demo_video_autoencoder_forward_backward():

FILE: ltx_video/models/transformers/attention.py
  class BasicTransformerBlock (line 38) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 77) | def __init__(
    method set_use_tpu_flash_attention (line 184) | def set_use_tpu_flash_attention(self):
    method set_chunk_feed_forward (line 193) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int =...
    method forward (line 198) | def forward(
  class Attention (line 325) | class Attention(nn.Module):
    method __init__ (line 378) | def __init__(
    method set_use_tpu_flash_attention (line 526) | def set_use_tpu_flash_attention(self):
    method set_processor (line 532) | def set_processor(self, processor: "AttnProcessor") -> None:
    method get_processor (line 554) | def get_processor(
    method forward (line 660) | def forward(
    method batch_to_head_dim (line 720) | def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
    method head_to_batch_dim (line 739) | def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) ->...
    method get_attention_scores (line 771) | def get_attention_scores(
    method prepare_attention_mask (line 825) | def prepare_attention_mask(
    method norm_encoder_hidden_states (line 884) | def norm_encoder_hidden_states(
    method apply_rotary_emb (line 918) | def apply_rotary_emb(
  class AttnProcessor2_0 (line 935) | class AttnProcessor2_0:
    method __init__ (line 940) | def __init__(self):
    method __call__ (line 943) | def __call__(
  class AttnProcessor (line 1117) | class AttnProcessor:
    method __call__ (line 1122) | def __call__(
  class FeedForward (line 1204) | class FeedForward(nn.Module):
    method __init__ (line 1218) | def __init__(
    method forward (line 1257) | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> ...

FILE: ltx_video/models/transformers/embeddings.py
  function get_timestep_embedding (line 10) | def get_timestep_embedding(
  function get_3d_sincos_pos_embed (line 53) | def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
  function get_3d_sincos_pos_embed_from_grid (line 66) | def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
  function get_1d_sincos_pos_embed_from_grid (line 79) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  class SinusoidalPositionalEmbedding (line 103) | class SinusoidalPositionalEmbedding(nn.Module):
    method __init__ (line 115) | def __init__(self, embed_dim: int, max_seq_length: int = 32):
    method forward (line 126) | def forward(self, x):

FILE: ltx_video/models/transformers/symmetric_patchifier.py
  class Patchifier (line 10) | class Patchifier(ConfigMixin, ABC):
    method __init__ (line 11) | def __init__(self, patch_size: int):
    method patchify (line 16) | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
    method unpatchify (line 20) | def unpatchify(
    method patch_size (line 30) | def patch_size(self):
    method get_latent_coords (line 33) | def get_latent_coords(
  class SymmetricPatchifier (line 54) | class SymmetricPatchifier(Patchifier):
    method patchify (line 55) | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
    method unpatchify (line 67) | def unpatchify(

FILE: ltx_video/models/transformers/transformer3d.py
  class Transformer3DModelOutput (line 35) | class Transformer3DModelOutput(BaseOutput):
  class Transformer3DModel (line 48) | class Transformer3DModel(ModelMixin, ConfigMixin):
    method __init__ (line 52) | def __init__(
    method set_use_tpu_flash_attention (line 162) | def set_use_tpu_flash_attention(self):
    method create_skip_layer_mask (line 173) | def create_skip_layer_mask(
    method _set_gradient_checkpointing (line 190) | def _set_gradient_checkpointing(self, module, value=False):
    method get_fractional_positions (line 194) | def get_fractional_positions(self, indices_grid):
    method precompute_freqs_cis (line 204) | def precompute_freqs_cis(self, indices_grid, spacing="exp"):
    method load_state_dict (line 259) | def load_state_dict(
    method from_pretrained (line 274) | def from_pretrained(
    method forward (line 330) | def forward(

FILE: ltx_video/pipelines/crf_compressor.py
  function _encode_single_frame (line 7) | def _encode_single_frame(output_file, image_array: np.ndarray, crf):
  function _decode_single_frame (line 24) | def _decode_single_frame(video_file):
  function compress (line 34) | def compress(image: torch.Tensor, crf=29):

FILE: ltx_video/pipelines/pipeline_ltx_video.py
  function retrieve_timesteps (line 125) | def retrieve_timesteps(
  class ConditioningItem (line 194) | class ConditioningItem:
  class LTXVideoPipeline (line 213) | class LTXVideoPipeline(DiffusionPipeline):
    method __init__ (line 262) | def __init__(
    method mask_text_embeddings (line 298) | def mask_text_embeddings(self, emb, mask):
    method encode_prompt (line 307) | def encode_prompt(
    method prepare_extra_step_kwargs (line 479) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 500) | def check_inputs(
    method _text_preprocessing (line 586) | def _text_preprocessing(self, text):
    method add_noise_to_image_conditioning_latents (line 597) | def add_noise_to_image_conditioning_latents(
    method prepare_latents (line 623) | def prepare_latents(
    method classify_height_width_bin (line 704) | def classify_height_width_bin(
    method resize_and_crop_tensor (line 714) | def resize_and_crop_tensor(
    method resize_tensor (line 740) | def resize_tensor(media_items, height, width):
    method __call__ (line 754) | def __call__(
    method denoising_step (line 1348) | def denoising_step(
    method prepare_conditioning (line 1383) | def prepare_conditioning(
    method _resize_conditioning_item (line 1590) | def _resize_conditioning_item(
    method _get_latent_spatial_position (line 1605) | def _get_latent_spatial_position(
    method _handle_non_first_conditioning_sequence (line 1653) | def _handle_non_first_conditioning_sequence(
    method trim_conditioning_sequence (line 1728) | def trim_conditioning_sequence(
    method tone_map_latents (line 1749) | def tone_map_latents(
  function adain_filter_latent (line 1790) | def adain_filter_latent(
  class LTXMultiScalePipeline (line 1821) | class LTXMultiScalePipeline:
    method _upsample_latents (line 1822) | def _upsample_latents(
    method __init__ (line 1836) | def __init__(
    method __call__ (line 1843) | def __call__(

FILE: ltx_video/schedulers/rf.py
  function linear_quadratic_schedule (line 25) | def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_s...
  function simple_diffusion_resolution_dependent_timestep_shift (line 49) | def simple_diffusion_resolution_dependent_timestep_shift(
  function time_shift (line 69) | def time_shift(mu: float, sigma: float, t: Tensor):
  function get_normal_shift (line 73) | def get_normal_shift(
  function strech_shifts_to_terminal (line 85) | def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1):
  function sd3_resolution_dependent_timestep_shift (line 112) | def sd3_resolution_dependent_timestep_shift(
  class TimestepShifter (line 152) | class TimestepShifter(ABC):
    method shift_timesteps (line 154) | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor...
  class RectifiedFlowSchedulerOutput (line 159) | class RectifiedFlowSchedulerOutput(BaseOutput):
  class RectifiedFlowScheduler (line 176) | class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
    method __init__ (line 180) | def __init__(
    method get_initial_timesteps (line 201) | def get_initial_timesteps(
    method shift_timesteps (line 216) | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor...
    method set_timesteps (line 227) | def set_timesteps(
    method from_pretrained (line 264) | def from_pretrained(pretrained_model_path: Union[str, os.PathLike]):
    method scale_model_input (line 288) | def scale_model_input(
    method step (line 305) | def step(
    method add_noise (line 376) | def add_noise(

FILE: ltx_video/utils/diffusers_config_mapping.py
  function make_hashable_key (line 1) | def make_hashable_key(dict_key):

FILE: ltx_video/utils/prompt_enhance_utils.py
  function tensor_to_pil (line 47) | def tensor_to_pil(tensor):
  function generate_cinematic_prompt (line 64) | def generate_cinematic_prompt(
  function _get_first_frames_from_conditioning_item (line 113) | def _get_first_frames_from_conditioning_item(conditioning_item) -> List[...
  function _generate_t2v_prompt (line 121) | def _generate_t2v_prompt(
  function _generate_i2v_prompt (line 151) | def _generate_i2v_prompt(
  function _generate_image_captions (line 188) | def _generate_image_captions(
  function _generate_and_decode_prompts (line 211) | def _generate_and_decode_prompts(

FILE: ltx_video/utils/skip_layer_strategy.py
  class SkipLayerStrategy (line 4) | class SkipLayerStrategy(Enum):

FILE: ltx_video/utils/torch_utils.py
  function append_dims (line 5) | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
  class Identity (line 17) | class Identity(nn.Module):
    method __init__ (line 20) | def __init__(self, *args, **kwargs) -> None:  # pylint: disable=unused...
    method forward (line 24) | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:

FILE: tests/conftest.py
  function pytest_make_parametrize_id (line 14) | def pytest_make_parametrize_id(config, val, argname):
  function num_latent_channels (line 21) | def num_latent_channels():
  function video_autoencoder (line 26) | def video_autoencoder(num_latent_channels):
  function transformer_config (line 34) | def transformer_config(num_latent_channels):
  function synthetic_ckpt_path (line 67) | def synthetic_ckpt_path(

FILE: tests/test_configs.py
  function prompt (line 10) | def prompt():
  function test_run_config (line 20) | def test_run_config(tmp_path, prompt, pipeline_config):

FILE: tests/test_inference.py
  function input_image_path (line 16) | def input_image_path():
  function input_video_path (line 21) | def input_video_path():
  function base_inference_config (line 25) | def base_inference_config(tmp_path, pipeline_config):
  function base_pipeline_config (line 44) | def base_pipeline_config(synthetic_ckpt_path):
  function test_condition_modes (line 66) | def test_condition_modes(
  function test_vid2vid (line 93) | def test_vid2vid(tmp_path, input_video_path, base_pipeline_config):
  function test_pipeline_on_batch (line 106) | def test_pipeline_on_batch(tmp_path, base_pipeline_config):
  function test_prompt_enhancement (line 161) | def test_prompt_enhancement(tmp_path, base_pipeline_config):

FILE: tests/test_scheduler.py
  function init_latents_and_scheduler (line 6) | def init_latents_and_scheduler(sampler):
  function test_scheduler_default_behavior (line 18) | def test_scheduler_default_behavior(sampler):
  function test_scheduler_per_token (line 41) | def test_scheduler_per_token(sampler):
  function test_scheduler_t_not_in_list (line 71) | def test_scheduler_t_not_in_list(sampler):

FILE: tests/test_vae.py
  function test_encode_decode_shape (line 8) | def test_encode_decode_shape(video_autoencoder, num_latent_channels):
  function test_temporal_causality (line 32) | def test_temporal_causality(video_autoencoder):
  function test_downscale_factors (line 59) | def test_downscale_factors(

Condensed preview: 54 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitattributes",
    "chars": 225,
    "preview": "*.jpg filter=lfs diff=lfs merge=lfs -text\n*.jpeg filter=lfs diff=lfs merge=lfs -text\n*.png filter=lfs diff=lfs merge=lfs"
  },
  {
    "path": ".github/workflows/pylint.yml",
    "chars": 711,
    "preview": "name: Ruff\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\""
  },
  {
    "path": ".gitignore",
    "chars": 3200,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 524,
    "preview": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    # Ruff version.\n    rev: v0.2.2\n    hooks:\n      # Run"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 26536,
    "preview": "<div align=\"center\">\n\n# LTX-Video\n\n[![Website](https://img.shields.io/badge/Website-LTXV-181717?logo=google-chrome)](htt"
  },
  {
    "path": "configs/ltxv-13b-0.9.8-dev-fp8.yaml",
    "chars": 1398,
    "preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-dev-fp8.safetensors\"\ndownscale_factor: 0.6666666\nspatial_ups"
  },
  {
    "path": "configs/ltxv-13b-0.9.8-dev.yaml",
    "chars": 1330,
    "preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-dev.safetensors\"\ndownscale_factor: 0.6666666\nspatial_upscale"
  },
  {
    "path": "configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
    "chars": 1147,
    "preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-distilled-fp8.safetensors\"\ndownscale_factor: 0.6666666\nspati"
  },
  {
    "path": "configs/ltxv-13b-0.9.8-distilled.yaml",
    "chars": 1080,
    "preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-distilled.safetensors\"\ndownscale_factor: 0.6666666\nspatial_u"
  },
  {
    "path": "configs/ltxv-2b-0.9.1.yaml",
    "chars": 737,
    "preview": "pipeline_type: base\ncheckpoint_path: \"ltx-video-2b-v0.9.1.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scale: 0"
  },
  {
    "path": "configs/ltxv-2b-0.9.5.yaml",
    "chars": 737,
    "preview": "pipeline_type: base\ncheckpoint_path: \"ltx-video-2b-v0.9.5.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scale: 0"
  },
  {
    "path": "configs/ltxv-2b-0.9.6-dev.yaml",
    "chars": 741,
    "preview": "pipeline_type: base\ncheckpoint_path: \"ltxv-2b-0.9.6-dev-04-25.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scal"
  },
  {
    "path": "configs/ltxv-2b-0.9.6-distilled.yaml",
    "chars": 721,
    "preview": "pipeline_type: base\ncheckpoint_path: \"ltxv-2b-0.9.6-distilled-04-25.safetensors\"\nguidance_scale: 1\nstg_scale: 0\nrescalin"
  },
  {
    "path": "configs/ltxv-2b-0.9.8-distilled-fp8.yaml",
    "chars": 1112,
    "preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-2b-0.9.8-distilled-fp8.safetensors\"\ndownscale_factor: 0.6666666\nspatia"
  },
  {
    "path": "configs/ltxv-2b-0.9.8-distilled.yaml",
    "chars": 1045,
    "preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-2b-0.9.8-distilled.safetensors\"\ndownscale_factor: 0.6666666\nspatial_up"
  },
  {
    "path": "configs/ltxv-2b-0.9.yaml",
    "chars": 735,
    "preview": "pipeline_type: base\ncheckpoint_path: \"ltx-video-2b-v0.9.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scale: 0.7"
  },
  {
    "path": "inference.py",
    "chars": 277,
    "preview": "from transformers import HfArgumentParser\n\nfrom ltx_video.inference import infer, InferenceConfig\n\n\ndef main():\n    pars"
  },
  {
    "path": "ltx_video/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/inference.py",
    "chars": 26715,
    "preview": "import os\nimport random\nfrom datetime import datetime\nfrom pathlib import Path\nfrom diffusers.utils import logging\nfrom "
  },
  {
    "path": "ltx_video/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/models/autoencoders/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/models/autoencoders/causal_conv3d.py",
    "chars": 1759,
    "preview": "from typing import Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\n\nclass CausalConv3d(nn.Module):\n    def __init__(\n "
  },
  {
    "path": "ltx_video/models/autoencoders/causal_video_autoencoder.py",
    "chars": 52253,
    "preview": "import json\nimport os\nfrom functools import partial\nfrom types import SimpleNamespace\nfrom typing import Any, Mapping, O"
  },
  {
    "path": "ltx_video/models/autoencoders/conv_nd_factory.py",
    "chars": 2609,
    "preview": "from typing import Tuple, Union\n\nimport torch\n\nfrom ltx_video.models.autoencoders.dual_conv3d import DualConv3d\nfrom ltx"
  },
  {
    "path": "ltx_video/models/autoencoders/dual_conv3d.py",
    "chars": 6885,
    "preview": "import math\nfrom typing import Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ein"
  },
  {
    "path": "ltx_video/models/autoencoders/latent_upsampler.py",
    "chars": 7025,
    "preview": "from typing import Optional, Union\nfrom pathlib import Path\nimport os\nimport json\n\nimport torch\nimport torch.nn as nn\nfr"
  },
  {
    "path": "ltx_video/models/autoencoders/pixel_norm.py",
    "chars": 307,
    "preview": "import torch\nfrom torch import nn\n\n\nclass PixelNorm(nn.Module):\n    def __init__(self, dim=1, eps=1e-8):\n        super(P"
  },
  {
    "path": "ltx_video/models/autoencoders/pixel_shuffle.py",
    "chars": 1043,
    "preview": "import torch.nn as nn\nfrom einops import rearrange\n\n\nclass PixelShuffleND(nn.Module):\n    def __init__(self, dims, upsca"
  },
  {
    "path": "ltx_video/models/autoencoders/vae.py",
    "chars": 14404,
    "preview": "from typing import Optional, Union\n\nimport torch\nimport inspect\nimport math\nimport torch.nn as nn\nfrom diffusers import "
  },
  {
    "path": "ltx_video/models/autoencoders/vae_encode.py",
    "chars": 8779,
    "preview": "from typing import Tuple\nimport torch\nfrom diffusers import AutoencoderKL\nfrom einops import rearrange\nfrom torch import"
  },
  {
    "path": "ltx_video/models/autoencoders/video_autoencoder.py",
    "chars": 35462,
    "preview": "import json\nimport os\nfrom functools import partial\nfrom types import SimpleNamespace\nfrom typing import Any, Mapping, O"
  },
  {
    "path": "ltx_video/models/transformers/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/models/transformers/attention.py",
    "chars": 52315,
    "preview": "import inspect\nfrom importlib import import_module\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch\nimport to"
  },
  {
    "path": "ltx_video/models/transformers/embeddings.py",
    "chars": 4471,
    "preview": "# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py\nimport math\n\nimpor"
  },
  {
    "path": "ltx_video/models/transformers/symmetric_patchifier.py",
    "chars": 2774,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Tuple\n\nimport torch\nfrom diffusers.configuration_utils import Con"
  },
  {
    "path": "ltx_video/models/transformers/transformer3d.py",
    "chars": 22744,
    "preview": "# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.p"
  },
  {
    "path": "ltx_video/pipelines/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/pipelines/crf_compressor.py",
    "chars": 1533,
    "preview": "import av\nimport torch\nimport io\nimport numpy as np\n\n\ndef _encode_single_frame(output_file, image_array: np.ndarray, crf"
  },
  {
    "path": "ltx_video/pipelines/pipeline_ltx_video.py",
    "chars": 83428,
    "preview": "# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_"
  },
  {
    "path": "ltx_video/schedulers/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/schedulers/rf.py",
    "chars": 15283,
    "preview": "import math\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Callable, Optional,"
  },
  {
    "path": "ltx_video/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ltx_video/utils/diffusers_config_mapping.py",
    "chars": 5578,
    "preview": "def make_hashable_key(dict_key):\n    def convert_value(value):\n        if isinstance(value, list):\n            return tu"
  },
  {
    "path": "ltx_video/utils/prompt_enhance_utils.py",
    "chars": 7651,
    "preview": "import logging\nfrom typing import Union, List, Optional\n\nimport torch\nfrom PIL import Image\n\nlogger = logging.getLogger("
  },
  {
    "path": "ltx_video/utils/skip_layer_strategy.py",
    "chars": 169,
    "preview": "from enum import Enum, auto\n\n\nclass SkipLayerStrategy(Enum):\n    AttentionSkip = auto()\n    AttentionValues = auto()\n   "
  },
  {
    "path": "ltx_video/utils/torch_utils.py",
    "chars": 822,
    "preview": "import torch\nfrom torch import nn\n\n\ndef append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:\n    \"\"\"Appends d"
  },
  {
    "path": "pyproject.toml",
    "chars": 849,
    "preview": "[build-system]\nrequires = [\"setuptools>=42\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"ltx-vid"
  },
  {
    "path": "tests/conftest.py",
    "chars": 3050,
    "preview": "import json\nimport pytest\nimport safetensors.torch\nimport torch\n\nfrom ltx_video.models.autoencoders.causal_video_autoenc"
  },
  {
    "path": "tests/test_configs.py",
    "chars": 958,
    "preview": "import pytest\nfrom pathlib import Path\n\nfrom ltx_video.inference import infer, InferenceConfig\n\nCONFIGS_DIR = Path(__fil"
  },
  {
    "path": "tests/test_inference.py",
    "chars": 8780,
    "preview": "from dataclasses import asdict\nimport pytest\nimport torch\nimport yaml\n\nfrom ltx_video.inference import (\n    create_ltx_"
  },
  {
    "path": "tests/test_scheduler.py",
    "chars": 3688,
    "preview": "import pytest\nimport torch\nfrom ltx_video.schedulers.rf import RectifiedFlowScheduler\n\n\ndef init_latents_and_scheduler(s"
  },
  {
    "path": "tests/test_vae.py",
    "chars": 3173,
    "preview": "import pytest\nimport torch\nfrom ltx_video.models.autoencoders.causal_video_autoencoder import (\n    CausalVideoAutoencod"
  },
  {
    "path": "tests/utils/.gitattributes",
    "chars": 42,
    "preview": "*.mp4 filter=lfs diff=lfs merge=lfs -text\n"
  }
]

About this extraction

This page contains the full source code of the Lightricks/LTX-Video GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 54 files (418.1 KB), approximately 100.7k tokens, and a symbol index with 287 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
