Showing preview only (438K chars total). Download the full file or copy to clipboard to get everything.
Repository: Lightricks/LTX-Video
Branch: main
Commit: 4b2d05305762
Files: 54
Total size: 418.1 KB
Directory structure:
gitextract_jt2joi_e/
├── .gitattributes
├── .github/
│ └── workflows/
│ └── pylint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs/
│ ├── ltxv-13b-0.9.8-dev-fp8.yaml
│ ├── ltxv-13b-0.9.8-dev.yaml
│ ├── ltxv-13b-0.9.8-distilled-fp8.yaml
│ ├── ltxv-13b-0.9.8-distilled.yaml
│ ├── ltxv-2b-0.9.1.yaml
│ ├── ltxv-2b-0.9.5.yaml
│ ├── ltxv-2b-0.9.6-dev.yaml
│ ├── ltxv-2b-0.9.6-distilled.yaml
│ ├── ltxv-2b-0.9.8-distilled-fp8.yaml
│ ├── ltxv-2b-0.9.8-distilled.yaml
│ └── ltxv-2b-0.9.yaml
├── inference.py
├── ltx_video/
│ ├── __init__.py
│ ├── inference.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── autoencoders/
│ │ │ ├── __init__.py
│ │ │ ├── causal_conv3d.py
│ │ │ ├── causal_video_autoencoder.py
│ │ │ ├── conv_nd_factory.py
│ │ │ ├── dual_conv3d.py
│ │ │ ├── latent_upsampler.py
│ │ │ ├── pixel_norm.py
│ │ │ ├── pixel_shuffle.py
│ │ │ ├── vae.py
│ │ │ ├── vae_encode.py
│ │ │ └── video_autoencoder.py
│ │ └── transformers/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── embeddings.py
│ │ ├── symmetric_patchifier.py
│ │ └── transformer3d.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── crf_compressor.py
│ │ └── pipeline_ltx_video.py
│ ├── schedulers/
│ │ ├── __init__.py
│ │ └── rf.py
│ └── utils/
│ ├── __init__.py
│ ├── diffusers_config_mapping.py
│ ├── prompt_enhance_utils.py
│ ├── skip_layer_strategy.py
│ └── torch_utils.py
├── pyproject.toml
└── tests/
├── conftest.py
├── test_configs.py
├── test_inference.py
├── test_scheduler.py
├── test_vae.py
└── utils/
└── .gitattributes
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
tests/utils/car.png filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .github/workflows/pylint.yml
================================================
name: Ruff
on: [push]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- name: Checkout repository and submodules
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.2.2 black==24.2.0
- name: Analyzing the code with ruff
run: |
ruff $(git ls-files '*.py')
- name: Verify that no Black changes are required
run: |
black --check $(git ls-files '*.py')
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# From inference.py
outputs/
*.mp4
*.png
!tests/utils/car.png
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.2.2
hooks:
# Run the linter.
- id: ruff
args: [--fix] # Automatically fix issues if possible.
types: [python] # Ensure it only runs on .py files.
- repo: https://github.com/psf/black
rev: 24.2.0 # Specify the version of Black you want
hooks:
- id: black
name: Black code formatter
language_version: python3 # Use the Python version you're targeting (e.g., 3.10)
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
# LTX-Video
[](https://ltx.video)
[](https://huggingface.co/Lightricks/LTX-Video)
[](https://app.ltx.studio/ltx-2-playground/t2v)
[](https://arxiv.org/abs/2501.00103)
[](https://github.com/Lightricks/LTX-Video-Trainer)
[](https://discord.gg/ltxplatform)
This is the official repository for LTX-Video.
</div>
---
## 🚀 **New: LTX-2 is Now Available!**
**We're excited to announce [LTX-2](https://github.com/Lightricks/LTX-2) - the next generation of LTX with synchronized audio+video generation!**
LTX-2 is the first DiT-based audio-video foundation model that contains all core capabilities of modern video generation in one model. **LTX-2 is now the primary home for LTX development** and includes significant improvements:
- 🎵 **Synchronized Audio+Video Generation** - Generate videos with perfectly synchronized audio
- 🎬 **Latest Model** - LTX-2 with improved quality and capabilities
- 🔌 **ComfyUI Integration** - Built into ComfyUI core for seamless workflows
- 🎯 **Advanced Features:**
- Multiple keyframe support
- IC-LoRA control models for precise generation
- Standard LoRA support for style customization
- Latent upsampler for multiscale pipelines
- 🛠️ **Training Tools** - LoRA training capabilities
- 📚 **Comprehensive Documentation** - Full documentation at [https://docs.ltx.video](https://docs.ltx.video)
- 🔄 **Active Development** - Ongoing improvements and community support
**[👉 Check out LTX-2 here](https://github.com/Lightricks/LTX-2)**
**[📖 View Documentation](https://docs.ltx.video)**
---
## Table of Contents
- [Introduction](#introduction)
- [What's New](#news)
- [Models](#models)
- [Quick Start Guide](#quick-start-guide)
- [Online demo](#online-inference)
- [Run locally](#run-locally)
- [Installation](#installation)
- [Inference](#inference)
- [ComfyUI Integration](#comfyui-integration)
- [Diffusers Integration](#diffusers-integration)
- [Model User Guide](#model-user-guide)
- [Community Contribution](#community-contribution)
- [Training](#training)
- [Control Models](#control-models)
- [Join Us!](#join-us)
- [Acknowledgement](#acknowledgement)
# Introduction
LTX-Video is the first DiT-based video generation model that contains all core capabilities of modern video generation in one model: synchronized audio and video, high fidelity, multiple performance modes, production-ready outputs, API access, and open access. It can generate up to 50 FPS videos at native 4K resolution with synchronized audio in one pass.
The model is trained on a large-scale dataset of diverse videos and can generate high-resolution videos with realistic and diverse content.
The model supports image-to-video, multi-keyframe conditioning, keyframe-based animation, video extension (both forward and backward), video-to-video transformations, and any combination of these features.
### Image-to-video examples
| | | |
|:---:|:---:|:---:|
|  |  |  |
|  |  |  |
|  |  |  |
### Controlled video examples
| | | |
|:---:|:---:|:---:|
|  |  |  |
| | |
|:---:|:---:|
|  |  |
# News
## October 23, 2025: LTX-2 Announced
Today we announced our newest foundation model, LTX-2. LTX-2 represents a major leap forward from our previous model, LTXV 0.9.8. Here’s what’s new:
* **Audio + Video, Together**: Visuals and sound are generated in one coherent process, with motion, dialogue, ambience, and music flowing simultaneously.
* **4K Fidelity**: Professional-grade precision with native 4K and up to 50 fps, sharp textures, clean motion, and synchronized audio.
* **Longer Generations**: LTX-2 supports longer, continuous clips with synchronized audio up to 10 seconds.
* **Low Cost & Efficiency**: Up to 50% lower compute cost than competing models, powered by a multi-GPU inference stack.
* **Creative Control**: Multi-keyframe conditioning, 3D camera logic, and LoRA fine-tuning deliver frame-level precision and style consistency.
For more details, please see our [blog post](https://website.ltx.video/blog/introducing-ltx-2). LTX-2 model weights, code, and benchmarks will be released to the community later in 2025.
## July, 16th, 2025: New Distilled models v0.9.8 with up to 60 seconds of video:
- Long shot generation in LTXV-13B!
* LTX-Video now supports up to 60 seconds of video.
* Compatible also with the official IC-LoRAs.
* Try now in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-i2v-long-multi-prompt.json).
- Release new distilled models:
* 13B distilled model [ltxv-13b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml)
* 2B distilled model [ltxv-2b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml)
* Both models are distilled from the same base model [ltxv-13b-0.9.8-dev](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml) and are compatible for use together in the same multiscale pipeline.
* Improved prompt understanding and detail generation
* Includes corresponding FP8 weights and workflows.
- Release a new detailer model [LTX-Video-ICLoRA-detailer-13B-0.9.8](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8)
* Available in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-upscale.json).
## July, 8th, 2025: New Control Models Released!
- Released three new control models for LTX-Video on HuggingFace:
* **Depth Control**: [LTX-Video-ICLoRA-depth-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-depth-13b-0.9.7)
* **Pose Control**: [LTX-Video-ICLoRA-pose-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-pose-13b-0.9.7)
* **Canny Control**: [LTX-Video-ICLoRA-canny-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-canny-13b-0.9.7)
## May, 14th, 2025: New distilled model 13B v0.9.7:
- Release a new 13B distilled model [ltxv-13b-0.9.7-distilled](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors)
* Amazing for iterative work - generates HD videos in 10 seconds, with low-res preview after just 3 seconds (on H100)!
* Does not require classifier-free guidance and spatio-temporal guidance.
* Supports sampling with 8 (recommended) or fewer diffusion steps.
* Also released a LoRA version of the distilled model, [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors)
* Requires only 1GB of VRAM
* Can be used with the full 13B model for fast inference
- Release a new quantized distilled model [ltxv-13b-0.9.7-distilled-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors) for *real-time* generation (on H100) with even less VRAM
## May, 5th, 2025: New model 13B v0.9.7:
- Release a new 13B model [ltxv-13b-0.9.7-dev](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors)
- Release a new quantized model [ltxv-13b-0.9.7-dev-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors) for faster inference with less VRAM
- Release new upscalers
* [ltxv-temporal-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors)
* [ltxv-spatial-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors)
- Breakthrough prompt adherence and physical understanding.
- New Pipeline for multi-scale video rendering for fast and high quality results
## April, 15th, 2025: New checkpoints v0.9.6:
- Release a new checkpoint [ltxv-2b-0.9.6-dev-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-dev-04-25.safetensors) with improved quality
- Release a new distilled model [ltxv-2b-0.9.6-distilled-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors)
* 15x faster inference than non-distilled model.
* Does not require classifier-free guidance and spatio-temporal guidance.
* Supports sampling with 8 (recommended) or fewer diffusion steps.
- Improved prompt adherence, motion quality and fine details.
- New default resolution and FPS: 1216 × 704 pixels at 30 FPS
* Still real time on H100 with the distilled model.
* Other resolutions and FPS are still supported.
- Support stochastic inference (can improve visual quality when using the distilled model)
## March, 5th, 2025: New checkpoint v0.9.5
- New license for commercial use ([OpenRail-M](https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.5.license.txt))
- Release a new checkpoint v0.9.5 with improved quality
- Support keyframes and video extension
- Support higher resolutions
- Improved prompt understanding
- Improved VAE
- New online web app in [LTX-Studio](https://app.ltx.studio/ltx-video)
- Automatic prompt enhancement
## February, 20th, 2025: More inference options
- Improve STG (Spatiotemporal Guidance) for LTX-Video
- Support MPS on macOS with PyTorch 2.3.0
- Add support for 8-bit model, LTX-VideoQ8
- Add TeaCache for LTX-Video
- Add [ComfyUI-LTXTricks](#comfyui-integration)
- Add Diffusion-Pipe
## December 31st, 2024: Research paper
- Release the [research paper](https://arxiv.org/abs/2501.00103)
## December 20th, 2024: New checkpoint v0.9.1
- Release a new checkpoint v0.9.1 with improved quality
- Support for STG / PAG
- Support loading checkpoints of LTX-Video in Diffusers format (conversion is done on-the-fly)
- Support offloading unused parts to CPU
- Support the new timestep-conditioned VAE decoder
- Reference contributions from the community in the readme file
- Relax transformers dependency
## November 21st, 2024: Initial release v0.9.0
- Initial release of LTX-Video
- Support text-to-video and image-to-video generation
# Models
| Name | Notes | inference.py config | ComfyUI workflow (Recommended) |
|-------------------------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|
| ltxv-13b-0.9.8-dev | Highest quality, requires more VRAM | [ltxv-13b-0.9.8-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml) | [ltxv-13b-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base.json) |
| [ltxv-13b-0.9.8-mix](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) | Mix ltxv-13b-dev and ltxv-13b-distilled in the same multi-scale rendering workflow for balanced speed-quality | N/A | [ltxv-13b-i2v-mixed-multiscale.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-mixed-multiscale.json) |
| [ltxv-13b-0.9.8-distilled](https://app.ltx.studio/motion-workspace?videoModel=ltxv) | Faster, less VRAM usage, slight quality reduction compared to 13b. Ideal for rapid iterations | [ltxv-13b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml) | [ltxv-13b-dist-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json) |
| ltxv-2b-0.9.8-distilled | Smaller model, slight quality reduction compared to 13b distilled. Ideal for fast generation with light VRAM usage | [ltxv-2b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml) | N/A |
| ltxv-13b-0.9.8-dev-fp8 | Quantized version of ltxv-13b | [ltxv-13b-0.9.8-dev-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev-fp8.yaml) | [ltxv-13b-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base-fp8.json) |
| ltxv-13b-0.9.8-distilled-fp8 | Quantized version of ltxv-13b-distilled | [ltxv-13b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled-fp8.yaml) | [ltxv-13b-dist-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) |
| ltxv-2b-0.9.8-distilled-fp8 | Quantized version of ltxv-2b-distilled | [ltxv-2b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled-fp8.yaml) | N/A |
| ltxv-2b-0.9.6 | Good quality, lower VRAM requirement than ltxv-13b | [ltxv-2b-0.9.6-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-dev.yaml) | [ltxvideo-i2v.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v.json) |
| ltxv-2b-0.9.6-distilled | 15× faster, real-time capable, fewer steps needed, no STG/CFG required | [ltxv-2b-0.9.6-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-distilled.yaml) | [ltxvideo-i2v-distilled.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v-distilled.json) |
# Quick Start Guide
## Online inference
The model is accessible right away via the following links:
- [LTX-Studio image-to-video (13B-mix)](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)
- [LTX-Studio image-to-video (13B distilled)](https://app.ltx.studio/motion-workspace?videoModel=ltxv)
- [Fal.ai image-to-video (13B full)](https://fal.ai/models/fal-ai/ltx-video-13b-dev/image-to-video)
- [Fal.ai image-to-video (13B distilled)](https://fal.ai/models/fal-ai/ltx-video-13b-distilled/image-to-video)
- [Replicate image-to-video](https://replicate.com/lightricks/ltx-video)
## Run locally
### Installation
The codebase was tested with Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
On macOS, MPS was tested with PyTorch 2.3.0, and should support PyTorch == 2.3 or >= 2.6.
```bash
git clone https://github.com/Lightricks/LTX-Video.git
cd LTX-Video
# create env
python -m venv env
source env/bin/activate
python -m pip install -e .\[inference\]
```
#### FP8 Kernels (optional)
[FP8 kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) developed for LTX-Video provide a performance boost on supported graphics cards (Ada architecture and later). To install FP8 kernels, follow the instructions in that repository.
### Inference
📝 **Note:** For best results, we recommend using our [ComfyUI](#comfyui-integration) workflow. We're working on updating the inference.py script to match the high quality and output fidelity of ComfyUI.
To use our model, please follow the inference code in [inference.py](./inference.py):
#### For image-to-video generation:
```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_PATH --conditioning_start_frames 0 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
#### Extending a video:
📝 **Note:** Input video segments must contain a multiple of 8 frames plus 1 (e.g., 9, 17, 25, etc.), and the target frame number should be a multiple of 8.
```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths VIDEO_PATH --conditioning_start_frames START_FRAME --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
#### For video generation with multiple conditions:
You can now generate a video conditioned on a set of images and/or short video segments.
Simply provide a list of paths to the images or video segments you want to condition on, along with their target frame numbers in the generated video. You can also specify the conditioning strength for each item (default: 1.0).
```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_OR_VIDEO_PATH_1 IMAGE_OR_VIDEO_PATH_2 --conditioning_start_frames TARGET_FRAME_1 TARGET_FRAME_2 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
### Using as a library
```python
from ltx_video.inference import infer, InferenceConfig
infer(
InferenceConfig(
pipeline_config="configs/ltxv-13b-0.9.8-distilled.yaml",
prompt=PROMPT,
height=HEIGHT,
width=WIDTH,
num_frames=NUM_FRAMES,
output_path="output.mp4",
)
)
```
## ComfyUI Integration
To use our model with ComfyUI, please follow the instructions at [https://github.com/Lightricks/ComfyUI-LTXVideo/](https://github.com/Lightricks/ComfyUI-LTXVideo/).
## Diffusers Integration
To use our model with the Diffusers Python library, check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
Diffusers also supports an 8-bit version of LTX-Video; [see details below](#ltx-videoq8).
# Model User Guide
## 📝 Prompt Engineering
When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure:
* Start with main action in a single sentence
* Add specific details about movements and gestures
* Describe character/object appearances precisely
* Include background and environment details
* Specify camera angles and movements
* Describe lighting and colors
* Note any changes or sudden events
* See [examples](#introduction) for more inspiration.
### Automatic Prompt Enhancement
When using `LTXVideoPipeline` directly, you can enable prompt enhancement by setting `enhance_prompt=True`.
## 🎮 Parameter Guide
* Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes. The model works on resolutions that are divisible by 32 and frame counts of the form 8k + 1 (e.g., 257). If the resolution is not divisible by 32 or the number of frames is not of the form 8k + 1, the input is padded with -1 and then cropped to the desired resolution and number of frames. The model works best at resolutions below 720 x 1280 and with fewer than 257 frames.
* Seed: Save seed values to recreate specific styles or compositions you like
* Guidance Scale: 3-3.5 are the recommended values
* Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
📝 For advanced parameters usage, please see `python inference.py --help`
## Community Contribution
### ComfyUI-LTXTricks 🛠️
A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.
- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
- **Features:**
- 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
- ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
- 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
- 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
- ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
- 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).
### LTX-VideoQ8 🎱 <a id="ltx-videoq8"></a>
**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA Ada GPUs.
- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
- **Features:**
- 🚀 Up to 3X speed-up with no accuracy loss
- 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
- 🛠️ Fine-tune 2B transformer models with precalculated latents
- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)
- **Diffusers integration:** A diffusers integration for the 8-bit model is already out! [Details here](https://github.com/sayakpaul/q8-ltx-video)
### TeaCache for LTX-Video 🍵 <a id="TeaCache"></a>
**TeaCache** is a training-free caching approach that leverages timestep differences across model outputs to accelerate LTX-Video inference by up to 2x without significant visual quality degradation.
- **Repository:** [TeaCache4LTX-Video](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4LTX-Video)
- **Features:**
- 🚀 Speeds up LTX-Video inference.
- 📊 Adjustable trade-offs between speed (up to 2x) and visual quality using configurable parameters.
- 🛠️ No retraining required: Works directly with existing models.
### Your Contribution
...is welcome! If you have a project or tool that integrates with LTX-Video,
please let us know by opening an issue or pull request.
# Training
We provide an open-source repository for fine-tuning the LTX-Video model: [LTX-Video-Trainer](https://github.com/Lightricks/LTX-Video-Trainer).
This repository supports both the 2B and 13B model variants, enabling full fine-tuning as well as LoRA (Low-Rank Adaptation) fine-tuning for more efficient training. This includes:
- **Control LoRAs**: Train custom control models like depth, pose, and canny control
- **Effect LoRAs**: Create specialized effects and transformations for video generation
Explore the repository to customize the model for your specific use cases!
More information and training instructions can be found in the [README](https://github.com/Lightricks/LTX-Video-Trainer/blob/main/README.md).
# Control Models
The [ComfyUI-LTXVideo](https://github.com/Lightricks/ComfyUI-LTXVideo) repository now contains workflows and models for three specialized models that enable precise control over LTX-Video generation:
Pose Control, Depth Control, and Canny Control.
**Example ComfyUI Workflow (for all control types):** [ic-lora.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ic_lora/ic-lora.json)
# Join Us
Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?
At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created.
If you are passionate about AI, computer vision, and video generation, we would love to hear from you!
Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.
# Acknowledgement
We are grateful to the following awesome projects, which we drew on when implementing LTX-Video:
* [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
## Citation
📄 Our tech report is out! If you find our work helpful, please ⭐️ star the repository and cite our paper.
```
@article{HaCohen2024LTXVideo,
title={LTX-Video: Realtime Video Latent Diffusion},
author={HaCohen, Yoav and Chiprut, Nisan and Brazowski, Benny and Shalem, Daniel and Moshe, Dudu and Richardson, Eitan and Levin, Eran and Shiran, Guy and Zabari, Nir and Gordon, Ori and Panet, Poriya and Weissbuch, Sapir and Kulikov, Victor and Bitterman, Yaki and Melumian, Zeev and Bibi, Ofir},
journal={arXiv preprint arXiv:2501.00103},
year={2024}
}
```
================================================
FILE: configs/ltxv-13b-0.9.8-dev-fp8.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
stg_scale: [0, 0, 4, 4, 4, 2, 1]
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
num_inference_steps: 30
skip_final_inference_steps: 3
cfg_star_rescale: true
second_pass:
guidance_scale: [1]
stg_scale: [1]
rescaling_scale: [1]
guidance_timesteps: [1.0]
skip_block_list: [27]
num_inference_steps: 30
skip_initial_inference_steps: 17
cfg_star_rescale: true
================================================
FILE: configs/ltxv-13b-0.9.8-dev.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
stg_scale: [0, 0, 4, 4, 4, 2, 1]
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
num_inference_steps: 30
skip_final_inference_steps: 3
cfg_star_rescale: true
second_pass:
guidance_scale: [1]
stg_scale: [1]
rescaling_scale: [1]
guidance_timesteps: [1.0]
skip_block_list: [27]
num_inference_steps: 30
skip_initial_inference_steps: 17
cfg_star_rescale: true
================================================
FILE: configs/ltxv-13b-0.9.8-distilled-fp8.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
tone_map_compression_ratio: 0.6
================================================
FILE: configs/ltxv-13b-0.9.8-distilled.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
tone_map_compression_ratio: 0.6
================================================
FILE: configs/ltxv-2b-0.9.1.yaml
================================================
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
================================================
FILE: configs/ltxv-2b-0.9.5.yaml
================================================
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
================================================
FILE: configs/ltxv-2b-0.9.6-dev.yaml
================================================
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
================================================
FILE: configs/ltxv-2b-0.9.6-distilled.yaml
================================================
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
num_inference_steps: 8
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: true
================================================
FILE: configs/ltxv-2b-0.9.8-distilled-fp8.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
================================================
FILE: configs/ltxv-2b-0.9.8-distilled.yaml
================================================
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
================================================
FILE: configs/ltxv-2b-0.9.yaml
================================================
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
================================================
FILE: inference.py
================================================
from transformers import HfArgumentParser
from ltx_video.inference import infer, InferenceConfig
def main():
    """CLI entry point: parse an ``InferenceConfig`` from argv and run ``infer`` with it."""
    arg_parser = HfArgumentParser(InferenceConfig)
    parsed_dataclasses = arg_parser.parse_args_into_dataclasses()
    infer(config=parsed_dataclasses[0])


if __name__ == "__main__":
    main()
================================================
FILE: ltx_video/__init__.py
================================================
================================================
FILE: ltx_video/inference.py
================================================
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml
import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
import torchvision.transforms.functional as TVF
from transformers import (
T5EncoderModel,
T5Tokenizer,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
)
from huggingface_hub import hf_hub_download
from dataclasses import dataclass, field
from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import (
ConditioningItem,
LTXVideoPipeline,
LTXMultiScalePipeline,
)
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
import ltx_video.pipelines.crf_compressor as crf_compressor
# Module-level logger obtained from diffusers' logging utilities, registered under the name "LTX-Video".
logger = logging.get_logger("LTX-Video")
def get_total_gpu_memory():
    """Return the total memory of CUDA device 0 in GiB, or 0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    bytes_per_gib = 1024**3
    device_props = torch.cuda.get_device_properties(0)
    return device_props.total_memory / bytes_per_gib
def get_device():
    """Return the preferred torch device name: "cuda", then "mps", falling back to "cpu"."""
    # Probed in priority order; the first available backend wins.
    candidates = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for device_name, is_available in candidates:
        if is_available():
            return device_name
    return "cpu"
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load an image, center-crop it to the target aspect ratio and convert it to a tensor.

    The image is center-cropped to the aspect ratio ``target_width / target_height``
    and, unless ``just_crop`` is set, resized to ``(target_width, target_height)``.
    It is then converted to a float tensor, blurred with a 3x3 Gaussian kernel
    (sigma=1.0), passed through ``crf_compressor.compress`` and rescaled from
    ``[0, 1]`` to ``[-1, 1]`` (assuming ``compress`` preserves the ``[0, 1]`` range).

    Args:
        image_input: A file path (``str`` or ``os.PathLike``) or a PIL Image object.
            Images of any PIL mode are converted to RGB.
        target_height: Desired height of the output tensor.
        target_width: Desired width of the output tensor.
        just_crop: If True, only crop the image to the target aspect ratio without
            resizing, so the output keeps the cropped source resolution.

    Returns:
        A 5D tensor of shape ``(1, 3, 1, height, width)``, i.e.
        (batch_size, channels, num_frames, height, width).

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(image_input) as opened_image:
            image = opened_image.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Normalize the mode (e.g. RGBA, L, P) so the tensor always has 3 channels,
        # exactly like images loaded from a path.
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    # Center-crop along whichever dimension is too long for the target aspect ratio.
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute the padding that grows a source size to a target size, split evenly per side.

    When the size difference is odd, the extra pixel goes to the bottom/right side.
    A target smaller than the source yields negative values.

    Returns:
        The padding as ``(left, right, top, bottom)``. This matches the ordering
        ``torch.nn.functional.pad`` uses for the last two dimensions.
    """

    def _split(total: int) -> tuple[int, int]:
        # Floor half goes before (top/left); the remainder goes after (bottom/right).
        before = total // 2
        return before, total - before

    pad_left, pad_right = _split(target_width - source_width)
    pad_top, pad_bottom = _split(target_height - source_height)
    return pad_left, pad_right, pad_top, pad_bottom
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a prompt into a short, filesystem-friendly slug.

    Only alphabetic characters and whitespace are kept, and everything is
    lowercased. Leading words are then joined with hyphens. Words are added only
    while the whole slug, including the hyphen separators, fits within
    ``max_len`` characters. If the first word alone is longer than ``max_len``,
    the result is the empty string.

    Args:
        text: The prompt to convert.
        max_len: Maximum length of the returned slug.

    Returns:
        The hyphen-joined slug, at most ``max_len`` characters long.
    """
    # Remove non-letters and convert to lowercase
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )
    result = []
    current_length = 0
    for word in clean_text.split():
        # Every word after the first also costs one character for the "-" separator.
        separator_length = 1 if result else 0
        new_length = current_length + separator_length + len(word)
        if new_length > max_len:
            break
        result.append(word)
        current_length = new_length
    return "-".join(result)
# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first non-existing path of the form ``dir/<stem>_<i><endswith><ext>``.

    ``<stem>`` joins ``base``, the slug ``convert_prompt_to_filename(prompt, max_len=30)``,
    ``seed`` and the first three entries of ``resolution`` formatted as ``AxBxC``.
    Indices ``0 .. index_range - 1`` are tried in order.

    Raises:
        FileExistsError: If every candidate index is already taken.
    """
    prompt_slug = convert_prompt_to_filename(prompt, max_len=30)
    resolution_tag = f"{resolution[0]}x{resolution[1]}x{resolution[2]}"
    stem = f"{base}_{prompt_slug}_{seed}_{resolution_tag}"
    tail = f"{endswith or ''}{ext}"

    candidates = (dir / f"{stem}_{index}{tail}" for index in range(index_range))
    free_path = next(
        (candidate for candidate in candidates if not os.path.exists(candidate)),
        None,
    )
    if free_path is None:
        raise FileExistsError(
            f"Could not find a unique filename after {index_range} attempts."
        )
    return free_path
def seed_everething(seed: int):
    """Seed every RNG used during inference so runs are reproducible.

    This seeds Python's ``random``, NumPy's global RNG and PyTorch's default
    generator. It also seeds the CUDA and MPS generators when those backends
    are available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Backend-specific seeding is only touched when the backend exists.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """Load the LTX-Video transformer from ``ckpt_path`` in the requested precision.

    Args:
        ckpt_path: Checkpoint path passed to ``Transformer3DModel.from_pretrained``.
        precision: Controls how the model is loaded:
            - ``"float8_e4m3fn"`` loads FP8 weights and patches the model with Q8 kernels.
            - ``"bfloat16"`` casts the loaded model to bfloat16.
            - Any other value returns the model as loaded.

    Returns:
        The loaded (and possibly patched/cast) transformer.

    Raises:
        ValueError: If FP8 precision is requested but the Q8-Kernels package is
            not importable.
    """
    if precision == "float8_e4m3fn":
        # Guard only the optional import. An ImportError raised while loading or
        # patching the model is a different failure and must not be reported as
        # "Q8-Kernels not found".
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            ) from err
        transformer = Transformer3DModel.from_pretrained(
            ckpt_path, dtype=torch.float8_e4m3fn
        )
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    if precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    return Transformer3DModel.from_pretrained(ckpt_path)
def _read_allowed_inference_steps(ckpt_path: Path):
    """Return ``allowed_inference_steps`` from the checkpoint's embedded JSON config.

    The JSON config is read from the ``"config"`` entry of the safetensors
    header metadata. ``None`` is returned in any of these cases:
    - the file has no metadata;
    - the metadata has no ``"config"`` entry;
    - the config does not define ``allowed_inference_steps``.
    """
    with safe_open(ckpt_path, framework="pt") as f:
        # `metadata()` may return None for files saved without header metadata.
        metadata = f.metadata() or {}
    config_str = metadata.get("config")
    if not config_str:
        return None
    return json.loads(config_str).get("allowed_inference_steps", None)


def _load_prompt_enhancer(
    image_caption_model_name_or_path: str,
    llm_model_name_or_path: str,
):
    """Load the image-captioning model/processor and the LLM/tokenizer for prompt enhancement.

    Returns:
        A 4-tuple ``(image_caption_model, image_caption_processor, llm_model, llm_tokenizer)``.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # captioning model repository; only point this at trusted models.
    image_caption_model = AutoModelForCausalLM.from_pretrained(
        image_caption_model_name_or_path, trust_remote_code=True
    )
    image_caption_processor = AutoProcessor.from_pretrained(
        image_caption_model_name_or_path, trust_remote_code=True
    )
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name_or_path,
        torch_dtype="bfloat16",
    )
    llm_tokenizer = AutoTokenizer.from_pretrained(
        llm_model_name_or_path,
    )
    return image_caption_model, image_caption_processor, llm_model, llm_tokenizer


def create_ltx_video_pipeline(
    ckpt_path: str,
    precision: str,
    text_encoder_model_name_or_path: str,
    sampler: Optional[str] = None,
    device: Optional[str] = None,
    enhance_prompt: bool = False,
    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
) -> LTXVideoPipeline:
    """Assemble an ``LTXVideoPipeline`` from a single LTX-Video checkpoint file.

    The VAE, the transformer and (by default) the scheduler are loaded from
    ``ckpt_path``. The T5 text encoder and tokenizer are loaded from the
    ``text_encoder`` and ``tokenizer`` subfolders of
    ``text_encoder_model_name_or_path``.

    Args:
        ckpt_path: Path to the ``.safetensors`` checkpoint.
        precision: Transformer precision, forwarded to ``create_transformer``.
        text_encoder_model_name_or_path: Repo id or local path holding the T5
            encoder and tokenizer.
        sampler: Selects the scheduler:
            - ``None``/empty or ``"from_checkpoint"`` loads the scheduler from
              the checkpoint.
            - ``"uniform"`` (case-insensitive) builds a uniform sampler.
            - Any other value builds a linear-quadratic sampler.
        device: Device the models and the pipeline are moved to.
        enhance_prompt: Whether to load the prompt-enhancement models.
        prompt_enhancer_image_caption_model_name_or_path: Captioning model,
            required when ``enhance_prompt`` is True.
        prompt_enhancer_llm_model_name_or_path: LLM, required when
            ``enhance_prompt`` is True.

    Returns:
        The assembled pipeline, moved to ``device``.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
        ValueError: If ``enhance_prompt`` is True but a prompt-enhancer model
            path is missing.
    """
    ckpt_path = Path(ckpt_path)
    # Raise explicitly rather than `assert`, which is stripped under `python -O`.
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint path {ckpt_path} does not exist")
    # Fail fast, before any heavy model loading.
    if enhance_prompt and not (
        prompt_enhancer_image_caption_model_name_or_path
        and prompt_enhancer_llm_model_name_or_path
    ):
        raise ValueError(
            "enhance_prompt=True requires both prompt_enhancer_image_caption_model_name_or_path "
            "and prompt_enhancer_llm_model_name_or_path"
        )

    allowed_inference_steps = _read_allowed_inference_steps(ckpt_path)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = create_transformer(ckpt_path, precision)

    # Use the scheduler stored in the checkpoint unless an explicit sampler is requested.
    if sampler == "from_checkpoint" or not sampler:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    else:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
        )

    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_model_name_or_path, subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_model_name_or_path, subfolder="tokenizer"
    )

    transformer = transformer.to(device)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)

    if enhance_prompt:
        (
            prompt_enhancer_image_caption_model,
            prompt_enhancer_image_caption_processor,
            prompt_enhancer_llm_model,
            prompt_enhancer_llm_tokenizer,
        ) = _load_prompt_enhancer(
            prompt_enhancer_image_caption_model_name_or_path,
            prompt_enhancer_llm_model_name_or_path,
        )
    else:
        prompt_enhancer_image_caption_model = None
        prompt_enhancer_image_caption_processor = None
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

    vae = vae.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
        "allowed_inference_steps": allowed_inference_steps,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to(device)
    return pipeline
def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    """Load a ``LatentUpsampler``, move it to ``device`` and switch it to eval mode.

    Args:
        latent_upsampler_model_path: Path passed to ``LatentUpsampler.from_pretrained``.
        device: Target device for the model.

    Returns:
        The loaded upsampler instance.
    """
    upsampler_model = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    # Called for its side effects; the return values of `.to()` / `.eval()` are not used.
    upsampler_model.to(device)
    upsampler_model.eval()
    return upsampler_model
def load_pipeline_config(pipeline_config: str):
    """Load a pipeline YAML config and return its parsed contents.

    ``pipeline_config`` is first resolved relative to this module's directory.
    If no file exists there, it is used as given (relative to the current
    working directory, or absolute).

    Args:
        pipeline_config: Path to the YAML config file.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.

    Raises:
        ValueError: If the file is found in neither location.
    """
    module_dir = Path(__file__).parent
    candidates = (module_dir / pipeline_config, Path(pipeline_config))
    path = next((candidate for candidate in candidates if candidate.is_file()), None)
    if path is None:
        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
    # Decode explicitly as UTF-8 so parsing does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
@dataclass
class InferenceConfig:
    """Options for a single LTX-Video generation run.

    The top-level ``inference.py`` script parses these from the command line
    with ``transformers.HfArgumentParser``. Library users can also construct
    the class directly and pass it to ``infer``. Each field's
    ``metadata["help"]`` string is its CLI help text.
    """

    prompt: str = field(metadata={"help": "Prompt for the generation"})
    # NOTE(review): annotated as `str`, but the default factory yields a
    # `pathlib.Path`, so consumers must accept either type.
    output_path: str = field(
        default_factory=lambda: Path(
            f"outputs/{datetime.today().strftime('%Y-%m-%d')}"
        ),
        metadata={"help": "Path to the folder to save the output video"},
    )

    # Pipeline settings
    pipeline_config: str = field(
        # Must name a config that actually ships under configs/; it is resolved by
        # load_pipeline_config.
        default="configs/ltxv-13b-0.9.8-dev.yaml",
        metadata={"help": "Path to the pipeline config file"},
    )
    seed: int = field(
        default=171198, metadata={"help": "Random seed for the inference"}
    )
    height: int = field(
        default=704, metadata={"help": "Height of the output video frames"}
    )
    width: int = field(
        default=1216, metadata={"help": "Width of the output video frames"}
    )
    num_frames: int = field(
        default=121,
        metadata={"help": "Number of frames to generate in the output video"},
    )
    frame_rate: int = field(
        default=30, metadata={"help": "Frame rate for the output video"}
    )
    offload_to_cpu: bool = field(
        default=False, metadata={"help": "Offloading unnecessary computations to CPU."}
    )
    negative_prompt: str = field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        metadata={"help": "Negative prompt for undesired features"},
    )

    # Video-to-video arguments
    input_media_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the input video (or image) to be modified using the video-to-video pipeline"
        },
    )

    # Conditioning
    image_cond_noise_scale: float = field(
        default=0.15,
        metadata={"help": "Amount of noise to add to the conditioned image"},
    )
    conditioning_media_paths: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of paths to conditioning media (images or videos). Each path will be used as a conditioning item."
        },
    )
    conditioning_strengths: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items."
        },
    )
    conditioning_start_frames: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of frame indices where each conditioning item should be applied. Must match the number of conditioning items."
        },
    )
def _resolve_model_path(model_name_or_path: str) -> str:
    """Return a local path for a model file, downloading it from the Hub if needed.

    An existing local file is returned unchanged; anything else is treated as a
    filename inside the ``Lightricks/LTX-Video`` model repository and fetched with
    ``hf_hub_download``.
    """
    if os.path.isfile(model_name_or_path):
        return model_name_or_path
    return hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename=model_name_or_path,
        repo_type="model",
    )


def _get_skip_layer_strategy(stg_mode: str) -> SkipLayerStrategy:
    """Map a spatiotemporal-guidance mode name to a ``SkipLayerStrategy``.

    Both the short (e.g. ``stg_av``) and long (e.g. ``attention_values``)
    spellings are accepted, case-insensitively.

    Raises:
        ValueError: If ``stg_mode`` is not a recognized mode.
    """
    strategies = {
        "stg_av": SkipLayerStrategy.AttentionValues,
        "attention_values": SkipLayerStrategy.AttentionValues,
        "stg_as": SkipLayerStrategy.AttentionSkip,
        "attention_skip": SkipLayerStrategy.AttentionSkip,
        "stg_r": SkipLayerStrategy.Residual,
        "residual": SkipLayerStrategy.Residual,
        "stg_t": SkipLayerStrategy.TransformerBlock,
        "transformer_block": SkipLayerStrategy.TransformerBlock,
    }
    try:
        return strategies[stg_mode.lower()]
    except KeyError:
        raise ValueError(
            f"Invalid spatiotemporal guidance mode: {stg_mode}"
        ) from None


def _validate_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: Optional[List[float]],
    conditioning_start_frames: Optional[List[int]],
    num_frames: int,
) -> List[float]:
    """Check the conditioning arguments for consistency.

    Returns:
        The conditioning strengths. When none were given, this is ``1.0`` for
        every media item.

    Raises:
        ValueError: If start frames are missing, the three lists differ in
            length, a strength lies outside [0, 1], or a start frame lies
            outside [0, num_frames - 1].
    """
    # Use default strengths of 1.0
    if not conditioning_strengths:
        conditioning_strengths = [1.0] * len(conditioning_media_paths)
    if not conditioning_start_frames:
        raise ValueError(
            "If `conditioning_media_paths` is provided, "
            "`conditioning_start_frames` must also be provided"
        )
    if len(conditioning_media_paths) != len(conditioning_strengths) or len(
        conditioning_media_paths
    ) != len(conditioning_start_frames):
        raise ValueError(
            "`conditioning_media_paths`, `conditioning_strengths`, "
            "and `conditioning_start_frames` must have the same length"
        )
    if any(s < 0 or s > 1 for s in conditioning_strengths):
        raise ValueError("All conditioning strengths must be between 0 and 1")
    if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
        raise ValueError(
            f"All conditioning start frames must be between 0 and {num_frames-1}"
        )
    return conditioning_strengths


def infer(config: InferenceConfig):
    """Generate a video (or a single image) from ``config`` and save it to disk.

    The steps are:

    1. Validate the conditioning arguments and the STG mode.
    2. Resolve or download the model checkpoints.
    3. Build the LTX-Video pipeline, wrapping it in ``LTXMultiScalePipeline``
       when the config asks for ``multi-scale``.
    4. Run it on dimensions padded up to multiples of 32 and on ``N * 8 + 1``
       frames.
    5. Crop the result back to the requested size and write one PNG or MP4 per
       batch item into the output folder.

    Raises:
        ValueError: On inconsistent conditioning arguments, an unknown
            ``stg_mode`` in the pipeline config, or a multi-scale config without
            ``spatial_upscaler_model_path``.
    """
    conditioning_media_paths = config.conditioning_media_paths
    conditioning_strengths = config.conditioning_strengths
    conditioning_start_frames = config.conditioning_start_frames
    if conditioning_media_paths:
        # Validate up front so bad arguments fail before any model is downloaded.
        conditioning_strengths = _validate_conditioning(
            conditioning_media_paths,
            conditioning_strengths,
            conditioning_start_frames,
            config.num_frames,
        )

    pipeline_config = load_pipeline_config(config.pipeline_config)
    # Remove `stg_mode` so it is not forwarded through `**pipeline_config` below.
    # `pop` with a default (rather than `del`) keeps configs without the key working.
    skip_layer_strategy = _get_skip_layer_strategy(
        pipeline_config.pop("stg_mode", "attention_values")
    )

    ltxv_model_path = _resolve_model_path(pipeline_config["checkpoint_path"])
    spatial_upscaler_model_path = pipeline_config.get("spatial_upscaler_model_path")
    if spatial_upscaler_model_path:
        spatial_upscaler_model_path = _resolve_model_path(spatial_upscaler_model_path)

    seed_everething(config.seed)

    if config.offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        # Offload only when get_total_gpu_memory() reports less than 30.
        offload_to_cpu = config.offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(config.output_path)
        if config.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pad height/width up to multiples of 32 and num_frames up to (N * 8 + 1).
    # The output is cropped back to the requested size after generation.
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
    padding = calculate_padding(
        config.height, config.width, height_padded, width_padded
    )
    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    device = get_device()

    # Prompt enhancement only runs for short prompts (fewer words than the threshold).
    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]
    prompt_word_count = len(config.prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )
    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=pipeline_config[
            "text_encoder_model_name_or_path"
        ],
        sampler=pipeline_config.get("sampler", None),
        device=device,
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=pipeline_config[
            "prompt_enhancer_image_caption_model_name_or_path"
        ],
        prompt_enhancer_llm_model_name_or_path=pipeline_config[
            "prompt_enhancer_llm_model_name_or_path"
        ],
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if config.input_media_path:
        media_item = load_media_file(
            media_path=config.input_media_path,
            height=config.height,
            width=config.width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    # Prepare input for the pipeline
    sample = {
        "prompt": config.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": config.negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    generator = torch.Generator(device=device).manual_seed(config.seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=config.frame_rate,
        **sample,
        media_items=media_item,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=config.image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames.
    # A zero pad on the bottom/right must map to the full extent, not to index 0.
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")
def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Build one ``ConditioningItem`` per conditioning media file.

    If ``pipeline`` exposes a callable ``trim_conditioning_sequence``, each
    item's frame count is passed through it first. A warning is logged whenever
    frames are dropped.

    Args:
        conditioning_media_paths: Paths to conditioning images or videos.
        conditioning_strengths: Conditioning strength for each media item.
        conditioning_start_frames: Output frame index at which each item applies.
        height: Height of the output frames.
        width: Width of the output frames.
        num_frames: Number of frames in the output video.
        padding: Padding applied to every loaded frame.
        pipeline: Pipeline used, when it supports it, to trim long sequences.

    Returns:
        A list of ``ConditioningItem`` objects, in input order.
    """
    trim_sequence = getattr(pipeline, "trim_conditioning_sequence", None)
    supports_trimming = callable(trim_sequence)

    items = []
    for media_path, item_strength, item_start in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        total_frames = get_media_num_frames(media_path)
        kept_frames = total_frames
        if supports_trimming:
            kept_frames = trim_sequence(item_start, total_frames, num_frames)
            if kept_frames < total_frames:
                logger.warning(
                    f"Trimming conditioning video {media_path} from {total_frames} to {kept_frames} frames."
                )

        loaded = load_media_file(
            media_path=media_path,
            height=height,
            width=width,
            max_frames=kept_frames,
            padding=padding,
            just_crop=True,
        )
        items.append(ConditioningItem(loaded, item_start, item_strength))
    return items
def get_media_num_frames(media_path: str) -> int:
    """Return the number of frames in a media file.

    Paths ending in ``.mp4``, ``.avi``, ``.mov`` or ``.mkv`` (case-insensitive)
    are opened as videos and their frames counted. Any other path is treated as a
    single image and yields ``1``.
    """
    if not media_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
        return 1
    reader = imageio.get_reader(media_path)
    try:
        return reader.count_frames()
    finally:
        # Release the underlying decoder even if counting frames fails.
        reader.close()
def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    """Load an image or video as a padded tensor.

    Paths ending in ``.mp4``, ``.avi``, ``.mov`` or ``.mkv`` (case-insensitive)
    are read as videos, keeping at most ``max_frames`` frames. Every frame (or
    the single image) goes through ``load_image_to_tensor_with_resize_and_crop``
    and is then padded with ``torch.nn.functional.pad(..., padding)``. That pad
    applies ``(left, right, top, bottom)`` to the last two dimensions.

    Video frames are concatenated along dim 2; presumably the layout is
    (B, C, F, H, W) — confirm against ``load_image_to_tensor_with_resize_and_crop``.

    Raises:
        ValueError: If no frames could be read from a video.
    """
    is_video = media_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))
    if not is_video:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        return torch.nn.functional.pad(media_tensor, padding)

    reader = imageio.get_reader(media_path)
    try:
        num_input_frames = min(reader.count_frames(), max_frames)
        # Read and preprocess the relevant frames from the video file.
        frames = []
        for frame_idx in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(frame_idx))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frames.append(torch.nn.functional.pad(frame_tensor, padding))
    finally:
        # Always release the decoder, including when a frame fails to load.
        reader.close()

    if not frames:
        raise ValueError(f"No frames could be read from video {media_path}")
    # Stack frames along the temporal dimension
    return torch.cat(frames, dim=2)
================================================
FILE: ltx_video/models/__init__.py
================================================
================================================
FILE: ltx_video/models/autoencoders/__init__.py
================================================
================================================
FILE: ltx_video/models/autoencoders/causal_conv3d.py
================================================
from typing import Tuple, Union
import torch
import torch.nn as nn
class CausalConv3d(nn.Module):
    """3D convolution whose temporal padding replicates edge frames.

    Spatial dimensions are padded by ``nn.Conv3d`` itself (``kernel // 2`` using
    ``spatial_padding_mode``). The time axis is padded manually in ``forward``:

    * In causal mode, the first frame is repeated in front only, so an output
      frame never depends on later input frames.
    * In non-causal mode, the first and last frames are repeated symmetrically.

    The number of padded frames accounts for temporal dilation, so a stride-1
    convolution preserves the number of frames.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Accept either a single int or an explicit (time, height, width) kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        else:
            kernel_size = tuple(kernel_size)
        self.time_kernel_size = kernel_size[0]
        # Only the time axis is dilated; remember it to size the temporal padding.
        self.time_dilation = dilation

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Apply the convolution to ``x``, with frames along dim 2.

        Args:
            x: Input tensor, indexed as (batch, channels, time, height, width).
            causal: If True, pad only at the start of the time axis. Otherwise
                pad both ends symmetrically.
        """
        # Effective temporal receptive field minus one: the frames needed so a
        # stride-1 convolution keeps the temporal length unchanged.
        time_pad = self.time_dilation * (self.time_kernel_size - 1)
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, time_pad, 1, 1))
            x = torch.cat((first_frame_pad, x), dim=2)
        else:
            half_pad = time_pad // 2
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, half_pad, 1, 1))
            last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, half_pad, 1, 1))
            x = torch.cat((first_frame_pad, x, last_frame_pad), dim=2)
        return self.conv(x)

    @property
    def weight(self):
        """Weight tensor of the wrapped ``nn.Conv3d``."""
        return self.conv.weight
================================================
FILE: ltx_video/models/autoencoders/causal_video_autoencoder.py
================================================
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union, List
from pathlib import Path
import torch
import numpy as np
from einops import rearrange
from torch import nn
from diffusers.utils import logging
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from safetensors import safe_open
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
from ltx_video.models.transformers.attention import Attention
from ltx_video.utils.diffusers_config_mapping import (
diffusers_and_ours_config_mapping,
make_hashable_key,
VAE_KEYS_RENAME_DICT,
)
# State-dict key prefix for per-channel latent statistics
# ("std-of-means", "mean-of-means"); such keys are split off and registered as buffers.
PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class CausalVideoAutoencoder(AutoencoderKLWrapper):
    """Causal video VAE built from the ``Encoder`` and ``Decoder`` classes below.

    On top of the wrapper it provides:

    * checkpoint loading for several on-disk layouts;
    * construction from a config dict;
    * a reconstructed ``config`` view;
    * the spatial and temporal down-scale factors implied by the encoder's
      block description.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Instantiate the VAE and load its weights from a local path.

        Supported layouts:

        * A directory containing ``autoencoder.pth`` and ``config.json``,
          optionally with ``per_channel_statistics.json``.
        * A diffusers-style directory with ``vae/config.json`` and
          ``vae/diffusion_pytorch_model.safetensors``. Keys are renamed via
          ``VAE_KEYS_RENAME_DICT``.
        * A single ``.safetensors`` file whose metadata holds a JSON ``config``
          with a ``vae`` entry.

        If ``torch_dtype`` is passed in ``kwargs``, the model is cast to it
        before the weights are loaded.

        Raises:
            ValueError: If the path matches none of the supported layouts.
        """
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if (
            pretrained_model_name_or_path.is_dir()
            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
        ):
            config_local_path = pretrained_model_name_or_path / "config.json"
            config = cls.load_config(config_local_path, **kwargs)

            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only load checkpoints from trusted sources.
            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))

            statistics_local_path = (
                pretrained_model_name_or_path / "per_channel_statistics.json"
            )
            if statistics_local_path.exists():
                with open(statistics_local_path, "r") as file:
                    data = json.load(file)
                # The JSON is row-oriented ("columns" plus "data" rows).
                # Transpose it into one tensor per column.
                transposed_data = list(zip(*data["data"]))
                data_dict = {
                    col: torch.tensor(vals)
                    for col, vals in zip(data["columns"], transposed_data)
                }
                std_of_means = data_dict["std-of-means"]
                mean_of_means = data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
                    std_of_means
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
                    mean_of_means
                )

        elif pretrained_model_name_or_path.is_dir():
            config_path = pretrained_model_name_or_path / "vae" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for VAE is not suppported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]

            state_dict_path = (
                pretrained_model_name_or_path
                / "vae"
                / "diffusion_pytorch_model.safetensors"
            )

            state_dict = {}
            with safe_open(state_dict_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            # Translate diffusers parameter names into this implementation's names.
            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)

                state_dict[new_key] = state_dict.pop(key)

        elif pretrained_model_name_or_path.is_file() and str(
            pretrained_model_name_or_path
        ).endswith(".safetensors"):
            state_dict = {}
            with safe_open(
                pretrained_model_name_or_path, framework="pt", device="cpu"
            ) as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            config = configs["vae"]

        else:
            # Previously this fell through to an UnboundLocalError on `config`.
            raise ValueError(
                f"Unsupported VAE checkpoint path: {pretrained_model_name_or_path}. "
                "Expected a directory containing autoencoder.pth, a diffusers-style "
                "directory with a vae/ subfolder, or a .safetensors file."
            )

        video_vae = cls.from_config(config)
        if "torch_dtype" in kwargs:
            video_vae.to(kwargs["torch_dtype"])
        video_vae.load_state_dict(state_dict)
        return video_vae

    @staticmethod
    def from_config(config):
        """Build an untrained ``CausalVideoAutoencoder`` from a config mapping.

        The caller's mapping is not modified.

        Raises:
            AssertionError: If ``_class_name`` or ``dims`` is unsupported.
            ValueError: If ``latent_log_var`` is ``uniform`` or ``constant``
                while ``use_quant_conv`` is enabled.
        """
        # Work on a shallow copy. The config may be an entry of the shared
        # diffusers_and_ours_config_mapping, which must not be mutated.
        config = dict(config)

        assert (
            config["_class_name"] == "CausalVideoAutoencoder"
        ), "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)
        normalize_latent_channels = config.get("normalize_latent_channels", False)

        if use_quant_conv and latent_log_var in ["uniform", "constant"]:
            raise ValueError(
                f"latent_log_var={latent_log_var} requires use_quant_conv=False"
            )

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            base_channels=config.get("encoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        """Configuration reconstructed from the instantiated encoder and decoder."""
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        """Spatial down-sampling: 2 per spatially-compressing block, times patch size."""
        spatial_blocks = (
            "compress_space",
            "compress_all",
            "compress_all_res",
            "compress_space_res",
        )
        num_compressions = sum(
            1 for block in self.encoder.blocks_desc if block[0] in spatial_blocks
        )
        return 2**num_compressions * self.encoder.patch_size

    @property
    def temporal_downscale_factor(self):
        """Temporal down-sampling: 2 per temporally-compressing encoder block."""
        temporal_blocks = (
            "compress_time",
            "compress_all",
            "compress_all_res",
            "compress_time_res",
        )
        num_compressions = sum(
            1 for block in self.encoder.blocks_desc if block[0] in temporal_blocks
        )
        return 2**num_compressions

    def to_json_string(self) -> str:
        """Serialize ``self.config`` to a JSON string."""
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load weights, accepting several checkpoint key conventions.

        * If any key starts with ``vae.``, only those keys are kept, with the
          prefix stripped.
        * Legacy module names are mapped to current ones (``.resnets.``,
          ``downsamplers.0``, ``upsamplers.0``).
        * ``norm`` parameters whose module does not exist in this model are
          dropped.
        * Keys under ``PER_CHANNEL_STATISTICS_PREFIX`` are registered as the
          ``std_of_means`` and ``mean_of_means`` buffers.

        Returns:
            Whatever the parent ``load_state_dict`` returns (for ``nn.Module``,
            the missing/unexpected keys).
        """
        if any(key.startswith("vae.") for key in state_dict.keys()):
            # Strip only the leading prefix; `replace` would also alter any
            # other "vae." occurrence inside a key.
            state_dict = {
                key.removeprefix("vae."): value
                for key, value in state_dict.items()
                if key.startswith("vae.")
            }
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }

        model_keys = set(name for name, _ in self.named_modules())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in ckpt_state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            key_prefix = ".".join(key.split(".")[:-1])
            if "norm" in key and key_prefix not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        result = super().load_state_dict(converted_state_dict, strict=strict)

        data_dict = {
            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
            for key, value in state_dict.items()
            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }
        if len(data_dict) > 0:
            self.register_buffer("std_of_means", data_dict["std-of-means"])
            self.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )
        return result

    def last_layer(self):
        """Return the decoder's final layer (the last module of ``conv_out`` when it is Sequential)."""
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        """Enable TPU flash attention on every attention block in the decoder's mid-blocks."""
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output (latent) channels.
        blocks (`List[Tuple[str, int | dict]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and either the number of
            layers or a dict of block parameters. `None` selects the default.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use: `group_norm`, `pixel_norm` or `layer_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
        spatial_padding_mode (`str`, *optional*, defaults to `zeros`):
            Padding mode forwarded to the convolutions for the spatial axes.

    Raises:
        ValueError: For an unknown block name, `norm_layer` or `latent_log_var`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: Optional[List[Tuple[str, int | dict]]] = None,
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm, layer_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        # A fresh default per instance avoids sharing one mutable default list.
        if blocks is None:
            blocks = [("res_x", 1)]
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        # Spatial patchify folds patch_size**2 pixels into the channel axis.
        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
        else:
            # Fail at construction instead of an AttributeError in forward().
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.conv_act = nn.SiLU()

        # Extra output channels carry the log-variance, depending on latent_log_var.
        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var == "constant":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            conv_out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        if self.gradient_checkpointing and self.training:

            def checkpoint_fn(module):
                # Bind the module as the checkpointed function. Its inputs are
                # supplied when the returned partial is called. Passing the
                # module alone to `checkpoint` would invoke it with no inputs.
                return partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )

        else:

            def checkpoint_fn(module):
                return module

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            # Broadcast the single log-variance channel so the output has
            # 2 * latent_channels channels, matching the per-channel layout.
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")
        elif self.latent_log_var == "constant":
            sample = sample[:, :-1, ...]
            approx_ln_0 = (
                -30
            )  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )

        return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name.startswith("compress"):
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(
torch.tensor(1000.0, dtype=torch.float32)
)
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
output_channel * 2, 0
)
self.last_scale_shift_table = nn.Parameter(
torch.randn(2, output_channel) / output_channel**0.5
)
def forward(
    self,
    sample: torch.FloatTensor,
    target_shape,
    timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
    r"""The forward method of the `Decoder` class.

    Args:
        sample: Latent tensor to decode; dimension 0 is the batch.
        target_shape: Requested output shape. Must not be `None`; it is not
            otherwise read by this method.
        timestep: Per-sample conditioning values, required when the decoder
            was built with `timestep_conditioning=True`. They are multiplied
            by the learned `timestep_scale_multiplier` before being passed to
            the `UNetMidBlock3D` up-blocks and to the final shift/scale.

    Returns:
        The decoded tensor after `unpatchify` with `self.patch_size`.
    """
    assert target_shape is not None, "target_shape must be provided"
    batch_size = sample.shape[0]

    sample = self.conv_in(sample, causal=self.causal)

    # Run the up blocks in the dtype of their parameters.
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

    # `torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)` needs the module
    # as its first positional argument. Bind it here so that
    # `checkpoint_fn(block)(sample, ...)` forwards the block's own inputs.
    checkpoint_fn = (
        (
            lambda module: partial(
                torch.utils.checkpoint.checkpoint, module, use_reentrant=False
            )
        )
        if self.gradient_checkpointing and self.training
        else (lambda module: module)
    )

    sample = sample.to(upscale_dtype)

    if self.timestep_conditioning:
        assert (
            timestep is not None
        ), "should pass timestep with timestep_conditioning=True"
        scaled_timestep = timestep * self.timestep_scale_multiplier

    for up_block in self.up_blocks:
        # Only `UNetMidBlock3D` blocks receive the timestep here.
        if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
            sample = checkpoint_fn(up_block)(
                sample, causal=self.causal, timestep=scaled_timestep
            )
        else:
            sample = checkpoint_fn(up_block)(sample, causal=self.causal)

    sample = self.conv_norm_out(sample)

    if self.timestep_conditioning:
        # Embed the timestep into 2 * C values per sample, add the learned
        # (2, C) table, and split into a per-channel shift and scale.
        embedded_timestep = self.last_time_embedder(
            timestep=scaled_timestep.flatten(),
            resolution=None,
            aspect_ratio=None,
            batch_size=sample.shape[0],
            hidden_dtype=sample.dtype,
        )
        embedded_timestep = embedded_timestep.view(
            batch_size, embedded_timestep.shape[-1], 1, 1, 1
        )
        ada_values = self.last_scale_shift_table[
            None, ..., None, None, None
        ] + embedded_timestep.reshape(
            batch_size,
            2,
            -1,
            embedded_timestep.shape[-3],
            embedded_timestep.shape[-2],
            embedded_timestep.shape[-1],
        )
        shift, scale = ada_values.unbind(dim=1)
        sample = sample * (1 + scale) + shift

    sample = self.conv_act(sample)
    sample = self.conv_out(sample, causal=self.causal)

    sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

    return sample
class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Each of the `num_layers` layers is a `ResnetBlock3D` mapping
    `in_channels -> in_channels`. When `attention_head_dim > 0`, each layer is
    followed by a self-attention block over all `frames * height * width`
    positions.

    Args:
        dims (`int` or `Tuple[int, int]`): Convolution dimensionality forwarded to `ResnetBlock3D`.
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
            If `None`, `min(in_channels // 4, 32)` is used.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
        attention_head_dim (`int`, *optional*, defaults to -1):
            The dimension of the attention head. If <= 0, no attention is used. Must not exceed
            `in_channels`; the number of heads is `in_channels // attention_head_dim`.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`):
            Padding mode forwarded to the convolutions of the resnet blocks.

    Returns:
        `torch.FloatTensor`: The output of the last residual (or attention) block, a tensor with the
        same shape as the input, `(batch_size, in_channels, frames, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: int = -1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        # Fall back to a channel-dependent group count when None is passed explicitly.
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            # Produces 4 * in_channels values per sample; each ResnetBlock3D
            # reshapes them into (shift1, scale1, shift2, scale2).
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                in_channels * 4, 0
            )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.attention_blocks = None

        if attention_head_dim > 0:
            if attention_head_dim > in_channels:
                raise ValueError(
                    "attention_head_dim must be less than or equal to in_channels"
                )

            # One attention block per resnet block; they are zipped in forward().
            self.attention_blocks = nn.ModuleList(
                [
                    Attention(
                        query_dim=in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        bias=True,
                        out_bias=True,
                        qk_norm="rms_norm",
                        residual_connection=True,
                    )
                    for _ in range(num_layers)
                ]
            )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run the resnet (and optional attention) layers.

        Args:
            hidden_states: 5D tensor `(batch_size, channel, frames, height, width)`.
            causal: Forwarded to the resnet blocks' convolutions.
            timestep: Required when `timestep_conditioning=True`. It is flattened
                and embedded once, then shared by all resnet blocks.
        """
        timestep_embed = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            # Broadcastable over (frames, height, width).
            timestep_embed = timestep_embed.view(
                batch_size, timestep_embed.shape[-1], 1, 1, 1
            )

        if self.attention_blocks:
            for resnet, attention in zip(self.res_blocks, self.attention_blocks):
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

                # Reshape the hidden states to be (batch_size, frames * height * width, channel)
                batch_size, channel, frames, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, frames * height * width
                ).transpose(1, 2)

                if attention.use_tpu_flash_attention:
                    # Pad the second dimension to be divisible by block_k_major (block in flash attention)
                    seq_len = hidden_states.shape[1]
                    block_k_major = 512
                    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
                    if pad_len > 0:
                        hidden_states = F.pad(
                            hidden_states, (0, 0, 0, pad_len), "constant", 0
                        )

                    # Create a mask with ones for the original sequence length and zeros for the padded indexes
                    mask = torch.ones(
                        (hidden_states.shape[0], seq_len),
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                    if pad_len > 0:
                        mask = F.pad(mask, (0, pad_len), "constant", 0)

                # A mask is only passed on the TPU flash-attention path.
                hidden_states = attention(
                    hidden_states,
                    attention_mask=(
                        None if not attention.use_tpu_flash_attention else mask
                    ),
                )

                if attention.use_tpu_flash_attention:
                    # Remove the padding
                    if pad_len > 0:
                        hidden_states = hidden_states[:, :-pad_len, :]

                # Reshape the hidden states back to (batch_size, channel, frames, height, width)
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, frames, height, width
                )
        else:
            for resnet in self.res_blocks:
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

        return hidden_states
class SpaceToDepthDownsample(nn.Module):
    """Downsample a 5D tensor by folding `stride`-sized blocks into channels.

    The output is the space-to-depth of a causal 3x3x3 convolution plus a
    parameter-free shortcut: the space-to-depth of the input, averaged over
    groups of `group_size` channels. When `stride[0] == 2` the first frame is
    duplicated first, so `d` input frames become `(d + 1) // 2`.
    """

    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
        super().__init__()
        self.stride = stride
        fold_volume = np.prod(stride)
        # Channels of the folded input averaged into each output channel.
        self.group_size = in_channels * fold_volume // out_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // fold_volume,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def _space_to_depth(self, tensor):
        # Move each (stride_t, stride_h, stride_w) block into the channel axis.
        return rearrange(
            tensor,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

    def forward(self, x, causal: bool = True):
        if self.stride[0] == 2:
            # Repeat the first frame so the temporal length divides by 2.
            x = torch.cat([x[:, :, :1, :, :], x], dim=2)

        shortcut = rearrange(
            self._space_to_depth(x), "b (c g) d h w -> b c g d h w", g=self.group_size
        ).mean(dim=2)

        folded = self._space_to_depth(self.conv(x, causal=causal))
        return folded + shortcut
class DepthToSpaceUpsample(nn.Module):
    """Upsample by predicting `prod(stride)` sub-positions per channel and
    shuffling them into time/space.

    When `stride[0] == 2` the first output frame is dropped, so `d` frames
    become `2 * d - 1`. With `residual=True`, a parameter-free shortcut is
    added: the pixel-shuffled input tiled along the channel axis.
    """

    def __init__(
        self,
        dims,
        in_channels,
        stride,
        residual=False,
        out_channels_reduction_factor=1,
        spatial_padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        upscale_volume = np.prod(stride)
        # Channels produced by the conv, before the shuffle spreads them out.
        self.out_channels = (
            upscale_volume * in_channels // out_channels_reduction_factor
        )
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(self, x, causal: bool = True):
        drops_first_frame = self.stride[0] == 2

        if self.residual:
            # Shuffle the raw input, then tile channels to match the conv branch.
            tile = np.prod(self.stride) // self.out_channels_reduction_factor
            shortcut = self.pixel_shuffle(x).repeat(1, tile, 1, 1, 1)
            if drops_first_frame:
                shortcut = shortcut[:, :, 1:, :, :]

        upsampled = self.pixel_shuffle(self.conv(x, causal=causal))
        if drops_first_frame:
            upsampled = upsampled[:, :, 1:, :, :]

        if self.residual:
            upsampled = upsampled + shortcut
        return upsampled
class LayerNorm(nn.Module):
    """Apply `nn.LayerNorm` over the channel axis of a `(b, c, d, h, w)` tensor.

    Channels are moved last, normalized, and moved back.
    """

    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        channels_last = rearrange(x, "b c d h w -> b d h w c")
        return rearrange(self.norm(channels_last), "b d h w c -> b c d h w")
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Computes `conv_shortcut(norm3(x)) + conv2(dropout(act(norm2(conv1(act(norm1(x)))))))`.
    Spatial noise can optionally be injected after each convolution, and the
    normalized activations can be modulated by timestep-dependent shift/scale.

    Parameters:
        dims (`int` or `Tuple[int, int]`): Convolution dimensionality forwarded to `make_conv_nd`.
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups used by the group normalization layers.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `"group_norm"`):
            One of `"group_norm"`, `"pixel_norm"` or `"layer_norm"`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Add a random spatial noise map, scaled by a learned per-channel factor, after each convolution.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Modulate the normalized activations with shift/scale values derived from `timestep`.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`): Padding mode for the convolutions.

    Raises:
        ValueError: If `norm_layer` is not one of the supported names.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        if norm_layer == "group_norm":
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
        else:
            # Fail at construction rather than with a missing `norm1` in forward().
            raise ValueError(f"unknown norm_layer: {norm_layer}")

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            # Noise is added to conv1's output, which has `out_channels` channels.
            self.per_channel_scale1 = nn.Parameter(torch.zeros((out_channels, 1, 1)))

        if norm_layer == "group_norm":
            self.norm2 = nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            # Noise is added to conv2's output, which has `out_channels` channels.
            self.per_channel_scale2 = nn.Parameter(torch.zeros((out_channels, 1, 1)))

        # 1x1 projection (and normalization) of the shortcut only when the channel count changes.
        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            # Rows: shift1, scale1, shift2, scale2.
            # NOTE(review): sized by in_channels, but shift2/scale2 modulate the
            # out_channels-wide output of norm2, so this only works when
            # in_channels == out_channels. Kept as-is for checkpoint compatibility.
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Add one random (H, W) noise map, scaled per channel, to `hidden_states`.

        The same map is broadcast over the batch and frame axes.
        """
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # similar to the "explicit noise inputs" method in style-gan
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Apply the block.

        `timestep` is required when `timestep_conditioning=True`. It is reshaped
        to `(batch, 4, -1, *timestep.shape[-3:])`, so presumably it has shape
        `(batch, 4 * in_channels, 1, 1, 1)` as built by `UNetMidBlock3D` —
        confirm for other callers.
        """
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            ada_values = self.scale_shift_table[
                None, ..., None, None, None
            ] + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale1
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale2
            )

        input_tensor = self.norm3(input_tensor)
        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
    """Fold spatial (and, for 5D input, temporal) patches into the channel axis.

    4D input `(b, c, H, W)` becomes `(b, c * p_hw**2, H / p_hw, W / p_hw)`.
    5D input `(b, c, F, H, W)` additionally folds `patch_size_t` frames.
    Returns `x` unchanged when both patch sizes are 1.

    Raises:
        ValueError: If `x` is neither 4D nor 5D (and patching is requested).
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    ndim = x.dim()
    if ndim == 5:
        return rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    if ndim == 4:
        return rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    raise ValueError(f"Invalid input shape: {x.shape}")
def unpatchify(x, patch_size_hw, patch_size_t=1):
    """Inverse of `patchify`: unfold patches from the channel axis back into space/time.

    4D input `(b, c * p_hw**2, h, w)` becomes `(b, c, h * p_hw, w * p_hw)`.
    5D input additionally unfolds `patch_size_t` frames. Returns `x`
    unchanged when both patch sizes are 1.

    Raises:
        ValueError: If `x` is neither 4D nor 5D (and unpatching is requested),
            matching `patchify`. Previously such input was silently returned unchanged.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x
def create_video_autoencoder_demo_config(
    latent_channels: int = 64,
):
    """Return a small `CausalVideoAutoencoder` configuration for demos.

    Args:
        latent_channels: Number of channels in the latent space.

    Returns:
        A fresh config dict; nested block-parameter dicts are never shared.
    """
    encoder_blocks = [("res_x", {"num_layers": 2})]
    encoder_blocks += [
        (name, {"multiplier": 2})
        for name in (
            "compress_space_res",
            "compress_time_res",
            "compress_all_res",
            "compress_all_res",
        )
    ]
    encoder_blocks.append(("res_x", {"num_layers": 1}))

    decoder_blocks = [("res_x", {"num_layers": 2, "inject_noise": False})]
    decoder_blocks += [
        ("compress_all", {"residual": True, "multiplier": 2}) for _ in range(3)
    ]
    decoder_blocks.append(("res_x", {"num_layers": 2, "inject_noise": False}))

    config = {
        "_class_name": "CausalVideoAutoencoder",
        "dims": 3,
        "encoder_blocks": encoder_blocks,
        "decoder_blocks": decoder_blocks,
        "latent_channels": latent_channels,
        "norm_layer": "pixel_norm",
        "patch_size": 4,
        "latent_log_var": "uniform",
        "use_quant_conv": False,
        "causal_decoder": False,
        "timestep_conditioning": True,
        "spatial_padding_mode": "replicate",
    }
    return config
def test_vae_patchify_unpatchify():
    """Round-tripping a 5D video through patchify/unpatchify is lossless."""
    import torch

    video = torch.randn(2, 3, 8, 64, 64)
    restored = unpatchify(
        patchify(video, patch_size_hw=4, patch_size_t=4),
        patch_size_hw=4,
        patch_size_t=4,
    )
    assert torch.allclose(video, restored)
def demo_video_autoencoder_forward_backward():
    """Smoke-test a small `CausalVideoAutoencoder` end to end.

    Builds the model from `create_video_autoencoder_demo_config`, encodes and
    decodes a random batch of videos, checks that encoding a single frame
    matches the first latent frame of the full video, and backpropagates an
    MSE reconstruction loss.
    """
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_demo_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)

    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Mock input: (batch_size, channels, frames, height, width) =
    # 2 videos, 3 color channels, 17 frames, 64x64 pixels per frame.
    input_videos = torch.randn(2, 3, 17, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")

    timestep = torch.ones(input_videos.shape[0]) * 0.1
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape, timestep=timestep
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Validate that a single image is encoded the same way as the first frame.
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    # target_shape is the pixel-space output shape, as in the video decode above.
    _ = video_autoencoder.decode(
        image_latent, target_shape=input_image.shape, timestep=timestep
    ).sample

    first_frame_latent = latent[:, :, :1, :, :]

    assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")
# When run as a script, execute the autoencoder demo (forward and backward pass).
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()
================================================
FILE: ltx_video/models/autoencoders/conv_nd_factory.py
================================================
from typing import Tuple, Union
import torch
from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d
def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
    spatial_padding_mode="zeros",
    temporal_padding_mode="zeros",
):
    """Build a convolution layer for the requested dimensionality.

    Args:
        dims: `2` -> `torch.nn.Conv2d`; `3` -> `torch.nn.Conv3d`, or
            `CausalConv3d` when `causal=True`; `(2, 1)` -> `DualConv3d`
            (a spatial convolution followed by a temporal one).
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias: Standard convolution arguments, forwarded to the
            chosen layer.
        causal: Only used for `dims == 3`.
        spatial_padding_mode: Padding mode forwarded to the layer.
        temporal_padding_mode: Non-causal layers use one padding mode for all
            axes, so it must equal `spatial_padding_mode` unless `causal`.

    Raises:
        NotImplementedError: If the padding modes differ and `causal` is False.
        ValueError: If `dims` is not supported.
    """
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        # DualConv3d supports dilation and groups; forward them instead of
        # silently falling back to its defaults.
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """Build a kernel-size-1 convolution acting as a per-position linear layer.

    Args:
        dims: `2` -> `torch.nn.Conv2d`; `3` or `(2, 1)` -> `torch.nn.Conv3d`.
            The previous `int` annotation did not reflect the accepted `(2, 1)`.
        in_channels: Input channel count.
        out_channels: Output channel count.
        bias: Whether the layer has a bias.

    Raises:
        ValueError: If `dims` is not supported.
    """
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    if dims in (3, (2, 1)):
        return torch.nn.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    raise ValueError(f"unsupported dimensions: {dims}")
================================================
FILE: ltx_video/models/autoencoders/dual_conv3d.py
================================================
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
    """3D convolution factorized into a spatial and a temporal convolution.

    The first convolution has kernel `(1, k_h, k_w)` and maps `in_channels` to
    `max(in_channels, out_channels)` channels. The second has kernel
    `(k_t, 1, 1)` and maps to `out_channels`. Stride, padding and dilation are
    split between the two in the same way.

    Args:
        in_channels, out_channels: Channel counts.
        kernel_size: int or `(t, h, w)` tuple; `(1, 1, 1)` is rejected.
        stride, padding, dilation: int or `(t, h, w)` tuple.
        groups: Groups used by both convolutions.
        bias: Whether both convolutions have a bias.
        padding_mode: One of `"zeros"`, `"reflect"`, `"replicate"`, `"circular"`.
    """

    _PADDING_MODES = ("zeros", "reflect", "replicate", "circular")

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super().__init__()

        if padding_mode not in self._PADDING_MODES:
            raise ValueError(
                f"padding_mode must be one of {self._PADDING_MODES}, got {padding_mode!r}"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if kernel_size == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)

        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias

        # Define the size of the channels after the first convolution
        intermediate_channels = (
            out_channels if in_channels < out_channels else in_channels
        )

        # First (spatial) convolution: kernel (1, k_h, k_w).
        self.weight1 = nn.Parameter(
            torch.empty(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.empty(intermediate_channels))
        else:
            self.register_parameter("bias1", None)

        # Second (temporal) convolution: kernel (k_t, 1, 1).
        self.weight2 = nn.Parameter(
            torch.empty(
                out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
            )
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias2", None)

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weights and biases like `torch.nn.ConvNd` does."""
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def _conv(self, conv_fn, x, weight, bias, stride, padding, dilation):
        """Apply `conv_fn` (an `F.conv{1,2,3}d`) while honoring `self.padding_mode`.

        The functional convolutions only zero-pad and take no `padding_mode`
        argument, so other modes are applied with `F.pad` followed by an
        unpadded convolution, as `torch.nn.ConvNd` does.
        """
        if isinstance(padding, int):
            padding = (padding,)
        if self.padding_mode == "zeros":
            return conv_fn(x, weight, bias, stride, padding, dilation, self.groups)
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        pad = [amount for p in reversed(padding) for amount in (p, p)]
        x = F.pad(x, pad, mode=self.padding_mode)
        return conv_fn(x, weight, bias, stride, 0, dilation, self.groups)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        """Apply the factorized convolution to a 5D `(b, c, d, h, w)` tensor.

        Args:
            use_conv3d: Use two `conv3d` calls instead of `conv2d` + `conv1d`.
            skip_time_conv: Return after the spatial convolution.
        """
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        else:
            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x, skip_time_conv):
        # First convolution
        x = self._conv(
            F.conv3d,
            x,
            self.weight1,
            self.bias1,
            self.stride1,
            self.padding1,
            self.dilation1,
        )

        if skip_time_conv:
            return x

        # Second convolution
        return self._conv(
            F.conv3d,
            x,
            self.weight2,
            self.bias2,
            self.stride2,
            self.padding2,
            self.dilation2,
        )

    def forward_with_2d(self, x, skip_time_conv):
        b, _, _, _, _ = x.shape

        # First 2D convolution, applied to every frame independently.
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # Squeeze the depth dimension out of weight1 since it's 1
        weight1 = self.weight1.squeeze(2)
        # Select stride, padding, and dilation for the 2D convolution
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
        x = self._conv(F.conv2d, x, weight1, self.bias1, stride1, padding1, dilation1)

        _, _, h, w = x.shape

        if skip_time_conv:
            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
            return x

        # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)

        # Reshape weight2 to match the expected dimensions for conv1d
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
        stride2 = self.stride2[0]
        padding2 = self.padding2[0]
        dilation2 = self.dilation2[0]
        x = self._conv(F.conv1d, x, weight2, self.bias2, stride2, padding2, dilation2)
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

        return x

    @property
    def weight(self):
        return self.weight2
def test_dual_conv3d_consistency():
    """The conv3d and conv2d+conv1d code paths of DualConv3d must agree."""
    layer = DualConv3d(
        in_channels=3,
        out_channels=5,
        kernel_size=(3, 3, 3),
        stride=(2, 2, 2),
        padding=(1, 1, 1),
        bias=True,
    )

    sample = torch.randn(1, 3, 10, 10, 10)

    via_3d = layer(sample, use_conv3d=True)
    via_2d = layer(sample, use_conv3d=False)

    assert torch.allclose(
        via_3d, via_2d, atol=1e-6
    ), "Outputs are not consistent between 3D and 2D convolutions."
================================================
FILE: ltx_video/models/autoencoders/latent_upsampler.py
================================================
from typing import Optional, Union
from pathlib import Path
import os
import json
import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
class ResBlock(nn.Module):
    """Residual block: conv -> GroupNorm -> SiLU -> conv -> GroupNorm, then SiLU(out + input).

    Args:
        channels: Input and output channel count (must be divisible by 32).
        mid_channels: Hidden channel count; defaults to `channels` (must be divisible by 32).
        dims: 2 for `nn.Conv2d`, anything else for `nn.Conv3d`.
    """

    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        hidden = channels if mid_channels is None else mid_channels
        conv_cls = nn.Conv2d if dims == 2 else nn.Conv3d

        # Creation order is kept fixed so parameter init and state_dict keys are stable.
        self.conv1 = conv_cls(channels, hidden, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, hidden)
        self.conv2 = conv_cls(hidden, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.activation(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(branch))
        return self.activation(branch + x)
class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to spatially upsample VAE latents.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent

    NOTE(review): with `dims=2` and `temporal_upsample=True` the upsampler is a
    `nn.Conv3d` but `forward` feeds it 4D tensors; that combination looks
    unsupported — confirm before relying on it.
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()

        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        # Sub-pixel upsampling: widen channels by the upscale volume, then shuffle.
        if spatial_upsample and temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )

        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Upsample a `(b, c, f, h, w)` latent.

        With temporal upsampling the first output frame is dropped, so `f`
        frames become `2 * f - 1`.
        """
        b, c, f, h, w = latent.shape

        if self.dims == 2:
            # Process every frame independently with 2D layers.
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                x = x[:, :, 1:, :, :]
            else:
                # Spatial-only upsampler is 2D: apply it per frame.
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x

    @classmethod
    def from_config(cls, config):
        """Build a model from a plain config dict.

        Missing keys fall back to values that differ from `__init__`'s
        defaults (e.g. `in_channels=4`, `mid_channels=128`, `dims=2`).
        """
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    # NOTE(review): this plain method shares its name with ConfigMixin's
    # `config` attribute — confirm callers expect a method here.
    def config(self):
        """Return the constructor arguments as a plain dict."""
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a model from a single `.safetensors` file.

        The file's metadata must contain a JSON `config` entry, which is passed
        to `from_config`. The model is built on the meta device and the weights
        are assigned directly from the file.

        Raises:
            ValueError: If the path is not an existing `.safetensors` file, or
                its metadata has no `config` entry.
        """
        pretrained_model_path = Path(pretrained_model_path)
        if not (
            pretrained_model_path.is_file()
            and str(pretrained_model_path).endswith(".safetensors")
        ):
            raise ValueError(
                f"Expected a path to an existing .safetensors file, got: {pretrained_model_path}"
            )

        state_dict = {}
        with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)

        if not metadata or "config" not in metadata:
            raise ValueError(
                f"{pretrained_model_path} has no 'config' entry in its safetensors metadata"
            )
        config = json.loads(metadata["config"])

        with torch.device("meta"):
            latent_upsampler = cls.from_config(config)
        latent_upsampler.load_state_dict(state_dict, assign=True)
        return latent_upsampler
if __name__ == "__main__":
    # Smoke test: build a 3D latent upsampler (spatial upsampling by default),
    # report its size, and upsample a random latent.
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(p.numel() for p in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    # (batch, channels, frames, height, width); 128 channels match the default in_channels.
    latent = torch.randn(1, 128, 9, 16, 16)
    # Spatial-only upsampling doubles height and width: expected (1, 128, 9, 32, 32).
    upsampled_latent = latent_upsampler(latent)
    print(f"Upsampled latent shape: {upsampled_latent.shape}")
================================================
FILE: ltx_video/models/autoencoders/pixel_norm.py
================================================
import torch
from torch import nn
class PixelNorm(nn.Module):
    """Normalize each position by the root-mean-square of its values along `dim`.

    Computes `x / sqrt(mean(x ** 2, dim) + eps)`; there are no learnable parameters.
    """

    def __init__(self, dim=1, eps=1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=self.dim, keepdim=True)
        return x / (mean_square + self.eps).sqrt()
================================================
FILE: ltx_video/models/autoencoders/pixel_shuffle.py
================================================
import torch.nn as nn
from einops import rearrange
class PixelShuffleND(nn.Module):
    """Move channel blocks into the time (1D), spatial (2D) or time+spatial (3D) axes.

    Args:
        dims: 1, 2 or 3. With 1 the tensor is `(b, c, f, h, w)` and only frames
            are upscaled; with 2 it is `(b, c, h, w)`; with 3 it is `(b, c, d, h, w)`.
        upscale_factors: Per-axis factors; only the first `dims` entries are used.
    """

    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        assert dims in (1, 2, 3), "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        factors = self.upscale_factors
        if self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=factors[0],
            )
        if self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=factors[0],
                p2=factors[1],
            )
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=factors[0],
                p2=factors[1],
                p3=factors[2],
            )
        return None
================================================
FILE: ltx_video/models/autoencoders/vae.py
================================================
from typing import Optional, Union
import torch
import inspect
import math
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.models.autoencoders.vae import (
DecoderOutput,
DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd
class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    """Variational Autoencoder (VAE) model with KL loss.

    VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
    This model wraps an encoder and a decoder; the encoder output ("moments") is turned into a
    ``DiagonalGaussianDistribution`` posterior, and latents are decoded back with the decoder.
    Optional tiling along axis 2 (z-tiling) and along height/width (hw-tiling) reduces memory.

    Args:
        encoder (`nn.Module`):
            Encoder module. Its ``down_blocks`` length is used to derive the latent tile size;
            z-tiled decoding also reads its ``patch_size`` and ``patch_size_t``.
        decoder (`nn.Module`):
            Decoder module. Called as ``decoder(z, target_shape=...)``, plus ``timestep=...``
            when its ``forward`` signature declares a ``timestep`` parameter.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels.
        dims (`int`, *optional*, defaults to 2):
            2 selects 2D quant convs / ``BatchNorm2d``; any other value selects the 3D variants.
        sample_size (`int`, *optional*, defaults to 512):
            Pixel tile size used by hw-tiling.
        use_quant_conv (`bool`, *optional*, defaults to `True`):
            Whether to apply 1x1 convs before/after the latent bottleneck.
        normalize_latent_channels (`bool`, *optional*, defaults to `False`):
            Whether to normalize latents with a non-affine BatchNorm (only the 3D variant is
            implemented).
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_channels: int = 4,
        dims: int = 2,
        sample_size=512,
        use_quant_conv: bool = True,
        normalize_latent_channels: bool = False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
        self.normalize_latent_channels = normalize_latent_channels

        # pass init params to Decoder
        quant_dims = 2 if dims == 2 else 3
        self.decoder = decoder
        if use_quant_conv:
            # The encoder side produces 2 * latent_channels "moments" channels.
            self.quant_conv = make_conv_nd(
                quant_dims, 2 * latent_channels, 2 * latent_channels, 1
            )
            self.post_quant_conv = make_conv_nd(
                quant_dims, latent_channels, latent_channels, 1
            )
        else:
            self.quant_conv = nn.Identity()
            self.post_quant_conv = nn.Identity()

        if normalize_latent_channels:
            if dims == 2:
                self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
            else:
                self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
        else:
            self.latent_norm_out = nn.Identity()
        self.use_z_tiling = False
        self.use_hw_tiling = False
        self.dims = dims
        self.z_sample_size = 1

        # Used by _decode to decide whether the decoder accepts a `timestep` argument.
        self.decoder_params = inspect.signature(self.decoder.forward).parameters

        # only relevant if vae tiling is enabled
        self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)

    def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
        """Configure hw-tiling.

        Args:
            sample_size: tile edge length in pixels.
            overlap_factor: fraction of a tile blended with its neighbour.
        """
        self.tile_sample_min_size = sample_size
        num_blocks = len(self.encoder.down_blocks)
        # Assumes every down block except the last halves the spatial resolution.
        self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
        self.tile_overlap_factor = overlap_factor

    def enable_z_tiling(self, z_sample_size: int = 8):
        r"""
        Enable tiling along axis 2 (frames, for 5D video tensors) during VAE encoding and decoding.

        When this option is enabled, the VAE will split the input tensor in tiles to compute
        encoding/decoding in several steps. This is useful to save some memory and allow larger
        batch sizes.

        Args:
            z_sample_size: chunk size along axis 2 of the pixel-space input; 1 disables tiling,
                otherwise it must be a multiple of 8.

        Raises:
            AssertionError: if ``z_sample_size`` is invalid; the tiling state is left unchanged.
        """
        # Validate before mutating so a rejected value does not leave tiling half-configured.
        assert (
            z_sample_size % 8 == 0 or z_sample_size == 1
        ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
        self.use_z_tiling = z_sample_size > 1
        self.z_sample_size = z_sample_size

    def disable_z_tiling(self):
        r"""
        Disable z-tiling. If `enable_z_tiling` was previously invoked, encoding and decoding go
        back to processing axis 2 in one step.
        """
        self.use_z_tiling = False

    def enable_hw_tiling(self):
        r"""
        Enable tiling during VAE encoding/decoding along the height and width dimension.
        """
        self.use_hw_tiling = True

    def disable_hw_tiling(self):
        r"""
        Disable tiling during VAE encoding/decoding along the height and width dimension.
        """
        self.use_hw_tiling = False

    def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
        """Encode ``x`` in overlapping spatial tiles and blend the results.

        ``return_dict`` is accepted for signature compatibility but unused.

        NOTE(review): unlike ``_encode``, this path does not apply
        ``_normalize_latent_channels`` - confirm whether that is intended.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into tile_sample_min_size x tile_sample_min_size tiles
        # and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        return moments

    def blend_z(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Linearly cross-fade the tail of ``a`` into the head of ``b`` along axis 2.

        Mutates ``b`` in place and returns it.
        """
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for z in range(blend_extent):
            b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
                1 - z / blend_extent
            ) + b[:, :, z, :, :] * (z / blend_extent)
        return b

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Linearly cross-fade the tail of ``a`` into the head of ``b`` along axis 3.

        Mutates ``b`` in place and returns it.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Linearly cross-fade the tail of ``a`` into the head of ``b`` along axis 4.

        Mutates ``b`` in place and returns it.
        """
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
        """Decode ``z`` in overlapping spatial tiles and blend the results.

        NOTE(review): unlike ``_decode``, this path neither applies
        ``_unnormalize_latent_channels`` nor forwards a ``timestep`` - confirm intent.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        tile_target_shape = (
            *target_shape[:3],
            self.tile_sample_min_size,
            self.tile_sample_min_size,
        )
        # Split z into overlapping tile_latent_min_size tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile, target_shape=tile_target_shape)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def encode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, tuple]:
        """Encode a sample batch into a diagonal-Gaussian posterior.

        Args:
            z: the (pixel-space) input sample, despite the parameter name.
            return_dict: if False, return a 1-tuple ``(posterior,)``.

        Returns:
            ``AutoencoderKLOutput(latent_dist=posterior)``, or ``(posterior,)``.
        """
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            # Chunks of z_sample_size along axis 2, plus a shorter remainder chunk if any.
            num_splits = z.shape[2] // self.z_sample_size
            sizes = [self.z_sample_size] * num_splits
            sizes = (
                sizes + [z.shape[2] - sum(sizes)]
                if z.shape[2] - sum(sizes) > 0
                else sizes
            )
            tiles = z.split(sizes, dim=2)
            moments_tiles = [
                (
                    self._hw_tiled_encode(z_tile, return_dict)
                    if self.use_hw_tiling
                    else self._encode(z_tile)
                )
                for z_tile in tiles
            ]
            moments = torch.cat(moments_tiles, dim=2)

        else:
            moments = (
                self._hw_tiled_encode(z, return_dict)
                if self.use_hw_tiling
                else self._encode(z)
            )

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Apply ``latent_norm_out`` to the first half of the channels of ``z``.

        The second half is passed through unchanged.

        Raises:
            NotImplementedError: for a ``BatchNorm2d`` norm.
        """
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            _, c, _, _, _ = z.shape
            z = torch.cat(
                [
                    self.latent_norm_out(z[:, : c // 2, :, :, :]),
                    z[:, c // 2 :, :, :, :],
                ],
                dim=1,
            )
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Invert the BatchNorm3d normalization using its running statistics.

        Raises:
            NotImplementedError: for a ``BatchNorm2d`` norm, mirroring
                ``_normalize_latent_channels`` (previously this branch repeated the
                ``BatchNorm3d`` check, was unreachable, and 2D latents were silently
                left normalized).
        """
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
            running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
            eps = self.latent_norm_out.eps

            z = z * torch.sqrt(running_var + eps) + running_mean
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Run encoder + quant conv + latent normalization; returns the moments tensor."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        moments = self._normalize_latent_channels(moments)
        return moments

    def _decode(
        self,
        z: torch.FloatTensor,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """Un-normalize, post-quant-conv and decode ``z`` in a single pass."""
        z = self._unnormalize_latent_channels(z)
        z = self.post_quant_conv(z)
        if "timestep" in self.decoder_params:
            dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
        else:
            dec = self.decoder(z, target_shape=target_shape)
        return dec

    def decode(
        self,
        z: torch.FloatTensor,
        return_dict: bool = True,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """Decode latents ``z``.

        Args:
            z: latent tensor; axis 2 is chunked when z-tiling is enabled.
            return_dict: if False, return a 1-tuple instead of ``DecoderOutput``.
            target_shape: required; the decoded output shape, forwarded to the decoder.
            timestep: optional decoder conditioning, forwarded (also in the z-tiled path)
                when the decoder's ``forward`` accepts ``timestep``. Not forwarded by the
                hw-tiled path.

        Raises:
            AssertionError: if ``target_shape`` is None.
        """
        assert target_shape is not None, "target_shape must be provided for decoding"
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            # NOTE(review): math.sqrt(patch_size) equals log2(patch_size) only when
            # patch_size == 4 - confirm the intended reduction factor.
            reduction_factor = int(
                self.encoder.patch_size_t
                * 2
                ** (
                    len(self.encoder.down_blocks)
                    - 1
                    - math.sqrt(self.encoder.patch_size)
                )
            )
            split_size = self.z_sample_size // reduction_factor
            num_splits = z.shape[2] // split_size

            # copy target shape, and divide frame dimension (=2) by the context size
            target_shape_split = list(target_shape)
            target_shape_split[2] = target_shape[2] // num_splits

            decoded_tiles = [
                (
                    self._hw_tiled_decode(z_tile, target_shape_split)
                    if self.use_hw_tiling
                    else self._decode(
                        z_tile, target_shape=target_shape_split, timestep=timestep
                    )
                )
                for z_tile in torch.tensor_split(z, num_splits, dim=2)
            ]
            decoded = torch.cat(decoded_tiles, dim=2)
        else:
            decoded = (
                self._hw_tiled_decode(z, target_shape)
                if self.use_hw_tiling
                else self._decode(z, target_shape=target_shape, timestep=timestep)
            )

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Encode ``sample`` and decode it back to its original shape.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior (otherwise its mode is used).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used to sample from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, target_shape=sample.shape).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
================================================
FILE: ltx_video/models/autoencoders/vae_encode.py
================================================
from typing import Tuple
import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor
from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.video_autoencoder import (
Downsample3D,
VideoAutoencoder,
)
try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None
def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.

    The function supports processing batches of images or video frames and can handle the
    processing in smaller sub-batches if needed. Latents are *sampled* from the VAE posterior
    (``latent_dist.sample()``) and then normalized with ``normalize_latents``.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos. ``channels`` must be 3.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library
            (or a ``VideoAutoencoder`` / ``CausalVideoAutoencoder``), pre-configured and loaded
            with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for
            encoding. If set to more than 1, the (possibly frame-flattened) batch is encoded in
            ``split_size`` equally sized chunks. Defaults to 1, which processes all items in a
            single batch.
        vae_per_channel_normalize (bool, optional): If True, normalize with the VAE's per-channel
            ``mean_of_means`` / ``std_of_means`` statistics; otherwise multiply by
            ``vae.config.scaling_factor``. Defaults to False.

    Returns:
        Tensor: A torch Tensor of the encoded latent representations. For 5D input encoded by a
        non-video VAE the frame axis is restored, so the result is (batch, channels, frames, h, w).

    Raises:
        ValueError: if the input does not have exactly 3 channels, or if ``split_size > 1`` and
            the batch size is not divisible by ``split_size``.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(images, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        For a video passed to a VAE that is neither ``VideoAutoencoder`` nor
        ``CausalVideoAutoencoder``, frames are flattened into the batch and encoded
        frame by frame. Video autoencoders receive the 5D tensor unchanged.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split"
            )
        encode_bs = len(media_items) // split_size
        latents = []
        # On XLA devices, cut the lazy graph around each chunk.
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents
def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    """Decode latents back to pixel space with ``vae``.

    5D latents given to a VAE that is neither ``VideoAutoencoder`` nor
    ``CausalVideoAutoencoder`` are flattened frame-wise into the batch, decoded,
    and reshaped back to ``(b, c, n, h, w)``.

    Args:
        latents: latent tensor (4D or 5D).
        vae: the autoencoder used for decoding.
        is_video: forwarded to ``_run_decoder``; for video VAEs, False requests a
            single-frame target shape.
        split_size: number of equally sized chunks to decode separately (1 = one pass).
        vae_per_channel_normalize: forwarded to ``un_normalize_latents``.
        timestep: optional decoder conditioning, forwarded to video VAEs.

    Raises:
        ValueError: if ``split_size > 1`` does not divide the batch size.
    """
    flatten_frames = latents.dim() == 5 and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    )
    batch_size = latents.shape[0]
    if flatten_frames:
        latents = rearrange(latents, "b c n h w -> (b n) c h w")

    if split_size <= 1:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )
    else:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split"
            )
        chunk_len = len(latents) // split_size
        decoded_chunks = []
        for chunk in latents.split(chunk_len):
            decoded_chunks.append(
                _run_decoder(chunk, vae, is_video, vae_per_channel_normalize, timestep)
            )
        images = torch.cat(decoded_chunks, dim=0)

    if flatten_frames:
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images
def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    """Un-normalize ``latents`` and decode them with ``vae`` in a single call.

    Video autoencoders receive the latents cast to ``vae.dtype`` plus an explicit
    ``target_shape`` of ``(1, 3, frames, height, width)``, scaled by the VAE's
    scale factors (one frame when ``is_video`` is False). ``timestep`` is only
    passed when it is not None. Other VAEs are called with ``return_dict=False`` only.
    """
    if not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        return vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]

    *_, frames, height, width = latents.shape
    time_scale, space_scale, _ = get_vae_size_scale_factor(vae)
    latents = latents.to(vae.dtype)
    extra_kwargs = {} if timestep is None else {"timestep": timestep}
    out_frames = frames * time_scale if is_video else 1
    target = (1, 3, out_frames, height * space_scale, width * space_scale)
    return vae.decode(
        un_normalize_latents(latents, vae, vae_per_channel_normalize),
        return_dict=False,
        target_shape=target,
        **extra_kwargs,
    )[0]
def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
    """Return the VAE's ``(temporal, spatial, spatial)`` pixel-per-latent scale factors.

    The previous ``-> float`` annotation was wrong: a 3-tuple is returned.

    For a ``CausalVideoAutoencoder`` the factors come from its
    ``temporal_downscale_factor`` / ``spatial_downscale_factor`` attributes.
    Otherwise they are ``config.patch_size`` (spatial) and ``config.patch_size_t``
    (temporal, ``VideoAutoencoder`` only; 1 for other VAEs), each multiplied by 2 for
    every encoder down block whose ``downsample`` is a ``Downsample3D``.
    """
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len(
            [
                block
                for block in vae.encoder.down_blocks
                if isinstance(block.downsample, Downsample3D)
            ]
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)
def latent_to_pixel_coords(
    latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
) -> Tensor:
    """
    Map latent coordinates to pixel coordinates using the VAE's scale factors.

    Args:
        latent_coords (Tensor): shape [batch_size, 3, num_latents]; the latent corner
            coordinates of each token.
        vae (AutoencoderKL): the VAE whose ``get_vae_size_scale_factor`` factors are applied.
        causal_fix (bool): whether to account for the different temporal scale of the
            first frame. Only honoured for a ``CausalVideoAutoencoder``.
            Default = False for backwards compatibility.

    Returns:
        Tensor: pixel coordinates with the same layout as ``latent_coords``.
    """
    return latent_to_pixel_coords_from_factors(
        latent_coords,
        get_vae_size_scale_factor(vae),
        isinstance(vae, CausalVideoAutoencoder) and causal_fix,
    )
def latent_to_pixel_coords_from_factors(
    latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
) -> Tensor:
    """Scale latent coordinates by explicit per-axis factors.

    Args:
        latent_coords: tensor of shape [batch, 3, num_latents]; axis 1 is (t, h, w).
        scale_factors: the (temporal, height, width) scale factors.
        causal_fix: if True, shift the temporal coordinate by ``1 - scale_factors[0]``
            and clamp at zero, so the first latent frame maps to a single pixel frame.

    Returns:
        A new tensor of pixel coordinates.
    """
    factors = torch.tensor(scale_factors, device=latent_coords.device)
    pixel_coords = latent_coords * factors[None, :, None]
    if not causal_fix:
        return pixel_coords
    # Fix temporal scale for first frame to 1 due to causality
    shifted = pixel_coords[:, 0] + 1 - scale_factors[0]
    pixel_coords[:, 0] = shifted.clamp(min=0)
    return pixel_coords
def _per_channel_stat(stat: Tensor, latents: Tensor) -> Tensor:
    """Cast a 1-D per-channel statistic to ``latents.dtype`` and reshape it to
    ``(1, C, 1, ..., 1)`` so it broadcasts along axis 1 of ``latents`` of any rank."""
    return stat.to(latents.dtype).view(1, -1, *([1] * (latents.dim() - 2)))


def normalize_latents(
    latents: Tensor, vae: "AutoencoderKL", vae_per_channel_normalize: bool = False
) -> Tensor:
    """Normalize VAE latents.

    With ``vae_per_channel_normalize`` the VAE's ``mean_of_means`` is subtracted and
    the result divided by ``std_of_means``, per channel (axis 1). The statistics are
    shaped to the rank of ``latents``, so both 4D and 5D latents work. The old fixed
    ``view(1, -1, 1, 1, 1)`` failed to broadcast for 4D latents. Otherwise the latents
    are multiplied by ``vae.config.scaling_factor``.
    """
    if vae_per_channel_normalize:
        mean = _per_channel_stat(vae.mean_of_means, latents)
        std = _per_channel_stat(vae.std_of_means, latents)
        return (latents - mean) / std
    return latents * vae.config.scaling_factor


def un_normalize_latents(
    latents: Tensor, vae: "AutoencoderKL", vae_per_channel_normalize: bool = False
) -> Tensor:
    """Invert ``normalize_latents`` (per-channel stats or ``scaling_factor``)."""
    if vae_per_channel_normalize:
        std = _per_channel_stat(vae.std_of_means, latents)
        mean = _per_channel_stat(vae.mean_of_means, latents)
        return latents * std + mean
    return latents / vae.config.scaling_factor
================================================
FILE: ltx_video/models/autoencoders/video_autoencoder.py
================================================
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional
from diffusers.utils import logging
from ltx_video.utils.torch_utils import Identity
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
# Module-level logger from diffusers' logging utilities; VideoAutoencoder.load_state_dict
# uses it to report state-dict keys it drops.
logger = logging.get_logger(__name__)
class VideoAutoencoder(AutoencoderKLWrapper):
    """Video VAE assembled from this module's `Encoder` and `Decoder`."""

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a `VideoAutoencoder` from a local directory.

        The directory must contain ``config.json`` and ``autoencoder.pth``. If
        ``per_channel_statistics.json`` exists, its ``std-of-means`` column (and
        ``mean-of-means``, defaulting to zeros) are registered as the buffers
        ``std_of_means`` / ``mean_of_means``.

        Args:
            pretrained_model_name_or_path: local directory, as ``str`` or path-like
                (a plain ``str`` previously failed on the ``/`` operator).
            **kwargs: forwarded to ``load_config``. If ``torch_dtype`` is given the
                model is cast to it (it used to be mandatory and raised ``KeyError``).
        """
        from pathlib import Path

        model_dir = Path(pretrained_model_name_or_path)
        config = cls.load_config(model_dir / "config.json", **kwargs)
        video_vae = cls.from_config(config)
        torch_dtype = kwargs.get("torch_dtype")
        if torch_dtype is not None:
            video_vae.to(torch_dtype)

        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted
        # checkpoints (consider weights_only=True where supported).
        ckpt_state_dict = torch.load(model_dir / "autoencoder.pth")
        video_vae.load_state_dict(ckpt_state_dict)

        statistics_local_path = model_dir / "per_channel_statistics.json"
        if statistics_local_path.exists():
            with open(statistics_local_path, "r") as file:
                data = json.load(file)
            # Rows -> columns, keyed by column name.
            transposed_data = list(zip(*data["data"]))
            data_dict = {
                col: torch.tensor(vals)
                for col, vals in zip(data["columns"], transposed_data)
            }
            video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
            video_vae.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

        return video_vae

    @staticmethod
    def from_config(config):
        """Build a `VideoAutoencoder` from a config mapping.

        The mapping is copied first, so the caller's object is no longer mutated
        when a list-valued ``dims`` is normalized to a tuple.

        Raises:
            AssertionError: on a wrong ``_class_name`` or unsupported ``dims``.
            ValueError: if ``latent_log_var == "uniform"`` is combined with quant convs.
        """
        config = dict(config)
        assert (
            config["_class_name"] == "VideoAutoencoder"
        ), "config must have _class_name=VideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)

        if use_quant_conv and latent_log_var == "uniform":
            raise ValueError("uniform latent_log_var requires use_quant_conv=False")

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        dims = config["dims"]
        return VideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
        )

    @property
    def config(self):
        """Reconstruct the model config from the instantiated encoder/decoder."""
        return SimpleNamespace(
            _class_name="VideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels
            // (self.encoder.patch_size_t * self.encoder.patch_size**2),
            out_channels=self.decoder.conv_out.out_channels
            // (self.decoder.patch_size_t * self.decoder.patch_size**2),
            latent_channels=self.decoder.conv_in.in_channels,
            block_out_channels=[
                self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
                for i in range(len(self.encoder.down_blocks))
            ],
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            patch_size_t=self.encoder.patch_size_t,
            add_channel_padding=self.encoder.add_channel_padding,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def downscale_factor(self):
        """Spatial downscale factor of the encoder.

        The `Encoder` exposes this as ``downscale_factor``; the previous lookup of
        ``downsample_factor`` always raised ``AttributeError``.
        """
        return self.encoder.downscale_factor

    def to_json_string(self) -> str:
        """Serialize ``self.config`` to a JSON string."""
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load a state dict, renaming diffusers-style keys to this module's names.

        ``.resnets.`` -> ``.res_blocks.``, ``downsamplers.0`` -> ``downsample`` and
        ``upsamplers.0`` -> ``upsample``. Keys containing ``norm`` that do not match a
        model parameter are dropped (logged at info level).
        """
        model_keys = {name for name, _ in self.named_parameters()}

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            if "norm" in key and key not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

    def last_layer(self):
        """Return the decoder's final layer (last element of ``conv_out`` if it is a Sequential)."""
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            Dimensionality passed to `make_conv_nd` and the inner blocks.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels (before patchification).
        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The spatial patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
        patch_size_t (`int`, *optional*):
            The temporal patch size; defaults to `patch_size`.
        add_channel_padding (`bool`, *optional*, defaults to `False`):
            If True, the patchified input has `in_channels * patch_size**3` channels instead of
            `in_channels * patch_size_t * patch_size**2`.

    Raises:
        ValueError: if `norm_layer` or `latent_log_var` is not a supported value.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        # Channel count after patchify folds patch pixels into channels.
        if add_channel_padding:
            in_channels = in_channels * self.patch_size**3
        else:
            in_channels = in_channels * self.patch_size_t * self.patch_size**2
        self.in_channels = in_channels
        output_channel = block_out_channels[0]

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = DownEncoderBlock3D(
                dims=dims,
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                add_downsample=not is_final_block and 2**i >= patch_size,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[-1],
                num_groups=norm_num_groups,
                eps=1e-6,
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        else:
            # Previously conv_norm_out was left undefined and forward() failed later.
            raise ValueError(f"Invalid norm_layer: {norm_layer}")
        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims, block_out_channels[-1], conv_out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    @property
    def downscale_factor(self):
        """``patch_size`` times 2 for every down block whose downsample is a `Downsample3D`."""
        return (
            2
            ** len(
                [
                    block
                    for block in self.down_blocks
                    if isinstance(block.downsample, Downsample3D)
                ]
            )
            * self.patch_size
        )

    def forward(
        self, sample: torch.FloatTensor, return_features=False
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class.

        Args:
            sample: input tensor. Axis 2 is treated as time: when its size is 1, no temporal
                patchify or downsampling is applied.
            return_features: if True, also return the output of each down block followed by
                the first `latent_channels` channels of the final output.

        Returns:
            The encoder output, or ``(output, features)`` when `return_features` is True.
        """
        downsample_in_time = sample.shape[2] != 1

        # patchify
        patch_size_t = self.patch_size_t if downsample_in_time else 1
        sample = patchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        sample = self.conv_in(sample)

        # Bind the module as checkpoint's function so the inputs given to the returned
        # callable reach it. The previous `partial(checkpoint, use_reentrant=False)(module)`
        # ran the module immediately with no inputs.
        checkpoint_fn = (
            (
                lambda module: partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else (lambda module: module)
        )

        if return_features:
            features = []
        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(
                sample, downsample_in_time=downsample_in_time
            )
            if return_features:
                features.append(sample)

        sample = checkpoint_fn(self.mid_block)(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            # Repeat the single log-variance channel so the output has
            # 2 * latent_channels channels.
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        if return_features:
            features.append(sample[:, : self.latent_channels, ...])
            return sample, features
        return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
patch_size_t: Optional[int] = None,
add_channel_padding: Optional[bool] = False,
):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
self.add_channel_padding = add_channel_padding
self.layers_per_block = layers_per_block
if add_channel_padding:
out_channels = out_channels * self.patch_size**3
else:
out_channels = out_channels * self.patch_size_t * self.patch_size**2
self.out_channels = out_channels
self.conv_in = make_conv_nd(
dims,
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.mid_block = UNetMidBlock3D(
dims=dims,
in_channels=block_out_channels[-1],
num_layers=self.layers_per_block,
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(reversed_block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock3D(
dims=dims,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block
and 2 ** (len(block_out_channels) - i - 1) > patch_size,
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
self.up_blocks.append(up_block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, block_out_channels[0], out_channels, 3, padding=1
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
    r"""The forward method of the `Decoder` class.

    Runs ``conv_in`` -> mid block -> up blocks -> norm/act/``conv_out`` and
    finally un-patchifies the result.

    Args:
        sample: Latent tensor to decode. Its dim 2 is compared with
            ``target_shape[2]`` to decide whether the up blocks upsample in
            time (presumably the frame axis of a (B, C, F, H, W) tensor —
            TODO confirm against callers).
        target_shape: Shape of the desired output. Must not be ``None``.

    Returns:
        The decoded tensor after ``unpatchify``.
    """
    assert target_shape is not None, "target_shape must be provided"
    # Upsample along dim 2 only when the input is shorter than the target there.
    upsample_in_time = sample.shape[2] < target_shape[2]

    sample = self.conv_in(sample)

    # The up blocks may run in a different dtype than the mid block output.
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

    use_checkpointing = self.gradient_checkpointing and self.training

    def checkpoint_fn(module):
        # Return a callable that runs `module` (optionally under activation
        # checkpointing). `torch.utils.checkpoint.checkpoint` executes right
        # away, so the module must be bound together with the call arguments
        # supplied later; passing the module alone would invoke it with no
        # inputs. Non-reentrant mode is required to forward keyword arguments.
        if use_checkpointing:
            return partial(
                torch.utils.checkpoint.checkpoint, module, use_reentrant=False
            )
        return module

    sample = checkpoint_fn(self.mid_block)(sample)
    sample = sample.to(upscale_dtype)

    for up_block in self.up_blocks:
        sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)

    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    # un-patchify; the temporal patch size only applies when upsampling in time
    patch_size_t = self.patch_size_t if upsample_in_time else 1
    sample = unpatchify(
        sample,
        patch_size_hw=self.patch_size,
        patch_size_t=patch_size_t,
        add_channel_padding=self.add_channel_padding,
    )

    return sample
class DownEncoderBlock3D(nn.Module):
def __init__(
    self,
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    dropout: float = 0.0,
    num_layers: int = 1,
    resnet_eps: float = 1e-6,
    resnet_groups: int = 32,
    add_downsample: bool = True,
    downsample_padding: int = 1,
    norm_layer: str = "group_norm",
):
    """Build ``num_layers`` residual blocks followed by an optional downsampler.

    The first residual block maps ``in_channels`` -> ``out_channels``; every
    later block maps ``out_channels`` -> ``out_channels``. When
    ``add_downsample`` is false, ``self.downsample`` is an ``Identity``.
    """
    super().__init__()

    # Stack of residual blocks; only the first one sees the input width.
    self.res_blocks = nn.ModuleList(
        [
            ResnetBlock3D(
                dims=dims,
                in_channels=out_channels if layer_idx > 0 else in_channels,
                out_channels=out_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                norm_layer=norm_layer,
            )
            for layer_idx in range(num_layers)
        ]
    )

    self.downsample = (
        Downsample3D(
            dims,
            out_channels,
            out_channels=out_channels,
            padding=downsample_padding,
        )
        if add_downsample
        else Identity()
    )
def forward(
self, hidden_states: torch.FloatTensor, downsample_in_time
) -> torch.FloatTensor:
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states)
hidden_states = self.downsamp
gitextract_jt2joi_e/
├── .gitattributes
├── .github/
│ └── workflows/
│ └── pylint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs/
│ ├── ltxv-13b-0.9.8-dev-fp8.yaml
│ ├── ltxv-13b-0.9.8-dev.yaml
│ ├── ltxv-13b-0.9.8-distilled-fp8.yaml
│ ├── ltxv-13b-0.9.8-distilled.yaml
│ ├── ltxv-2b-0.9.1.yaml
│ ├── ltxv-2b-0.9.5.yaml
│ ├── ltxv-2b-0.9.6-dev.yaml
│ ├── ltxv-2b-0.9.6-distilled.yaml
│ ├── ltxv-2b-0.9.8-distilled-fp8.yaml
│ ├── ltxv-2b-0.9.8-distilled.yaml
│ └── ltxv-2b-0.9.yaml
├── inference.py
├── ltx_video/
│ ├── __init__.py
│ ├── inference.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── autoencoders/
│ │ │ ├── __init__.py
│ │ │ ├── causal_conv3d.py
│ │ │ ├── causal_video_autoencoder.py
│ │ │ ├── conv_nd_factory.py
│ │ │ ├── dual_conv3d.py
│ │ │ ├── latent_upsampler.py
│ │ │ ├── pixel_norm.py
│ │ │ ├── pixel_shuffle.py
│ │ │ ├── vae.py
│ │ │ ├── vae_encode.py
│ │ │ └── video_autoencoder.py
│ │ └── transformers/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── embeddings.py
│ │ ├── symmetric_patchifier.py
│ │ └── transformer3d.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── crf_compressor.py
│ │ └── pipeline_ltx_video.py
│ ├── schedulers/
│ │ ├── __init__.py
│ │ └── rf.py
│ └── utils/
│ ├── __init__.py
│ ├── diffusers_config_mapping.py
│ ├── prompt_enhance_utils.py
│ ├── skip_layer_strategy.py
│ └── torch_utils.py
├── pyproject.toml
└── tests/
├── conftest.py
├── test_configs.py
├── test_inference.py
├── test_scheduler.py
├── test_vae.py
└── utils/
└── .gitattributes
SYMBOL INDEX (287 symbols across 28 files)
FILE: inference.py
function main (line 6) | def main():
FILE: ltx_video/inference.py
function get_total_gpu_memory (line 44) | def get_total_gpu_memory():
function get_device (line 51) | def get_device():
function load_image_to_tensor_with_resize_and_crop (line 59) | def load_image_to_tensor_with_resize_and_crop(
function calculate_padding (line 108) | def calculate_padding(
function convert_prompt_to_filename (line 128) | def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
function get_unique_filename (line 155) | def get_unique_filename(
function seed_everething (line 175) | def seed_everething(seed: int):
function create_transformer (line 185) | def create_transformer(ckpt_path: str, precision: str) -> Transformer3DM...
function create_ltx_video_pipeline (line 207) | def create_ltx_video_pipeline(
function create_latent_upsampler (line 294) | def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
function load_pipeline_config (line 301) | def load_pipeline_config(pipeline_config: str):
class InferenceConfig (line 317) | class InferenceConfig:
function infer (line 389) | def infer(config: InferenceConfig):
function prepare_conditioning (line 637) | def prepare_conditioning(
function get_media_num_frames (line 690) | def get_media_num_frames(media_path: str) -> int:
function load_media_file (line 702) | def load_media_file(
FILE: ltx_video/models/autoencoders/causal_conv3d.py
class CausalConv3d (line 7) | class CausalConv3d(nn.Module):
method __init__ (line 8) | def __init__(
method forward (line 44) | def forward(self, x, causal: bool = True):
method weight (line 62) | def weight(self):
FILE: ltx_video/models/autoencoders/causal_video_autoencoder.py
class CausalVideoAutoencoder (line 33) | class CausalVideoAutoencoder(AutoencoderKLWrapper):
method from_pretrained (line 35) | def from_pretrained(
method from_config (line 123) | def from_config(config):
method config (line 180) | def config(self):
method is_video_supported (line 201) | def is_video_supported(self):
method spatial_downscale_factor (line 208) | def spatial_downscale_factor(self):
method temporal_downscale_factor (line 228) | def temporal_downscale_factor(self):
method to_json_string (line 243) | def to_json_string(self) -> str:
method load_state_dict (line 248) | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool ...
method last_layer (line 298) | def last_layer(self):
method set_use_tpu_flash_attention (line 308) | def set_use_tpu_flash_attention(self):
class Encoder (line 315) | class Encoder(nn.Module):
method __init__ (line 340) | def __init__(
method forward (line 508) | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
class Decoder (line 558) | class Decoder(nn.Module):
method __init__ (line 583) | def __init__(
method forward (line 733) | def forward(
class UNetMidBlock3D (line 803) | class UNetMidBlock3D(nn.Module):
method __init__ (line 829) | def __init__(
method forward (line 895) | def forward(
class SpaceToDepthDownsample (line 974) | class SpaceToDepthDownsample(nn.Module):
method __init__ (line 975) | def __init__(self, dims, in_channels, out_channels, stride, spatial_pa...
method forward (line 989) | def forward(self, x, causal: bool = True):
class DepthToSpaceUpsample (line 1021) | class DepthToSpaceUpsample(nn.Module):
method __init__ (line 1022) | def __init__(
method forward (line 1049) | def forward(self, x, causal: bool = True):
class LayerNorm (line 1066) | class LayerNorm(nn.Module):
method __init__ (line 1067) | def __init__(self, dim, eps, elementwise_affine=True) -> None:
method forward (line 1071) | def forward(self, x):
class ResnetBlock3D (line 1078) | class ResnetBlock3D(nn.Module):
method __init__ (line 1091) | def __init__(
method _feed_spatial_noise (line 1181) | def _feed_spatial_noise(
method forward (line 1195) | def forward(
function patchify (line 1259) | def patchify(x, patch_size_hw, patch_size_t=1):
function unpatchify (line 1280) | def unpatchify(x, patch_size_hw, patch_size_t=1):
function create_video_autoencoder_demo_config (line 1300) | def create_video_autoencoder_demo_config(
function test_vae_patchify_unpatchify (line 1334) | def test_vae_patchify_unpatchify():
function demo_video_autoencoder_forward_backward (line 1343) | def demo_video_autoencoder_forward_backward():
FILE: ltx_video/models/autoencoders/conv_nd_factory.py
function make_conv_nd (line 9) | def make_conv_nd(
function make_linear_nd (line 75) | def make_linear_nd(
FILE: ltx_video/models/autoencoders/dual_conv3d.py
class DualConv3d (line 10) | class DualConv3d(nn.Module):
method __init__ (line 11) | def __init__(
method reset_parameters (line 86) | def reset_parameters(self):
method forward (line 97) | def forward(self, x, use_conv3d=False, skip_time_conv=False):
method forward_with_3d (line 103) | def forward_with_3d(self, x, skip_time_conv):
method forward_with_2d (line 133) | def forward_with_2d(self, x, skip_time_conv):
method weight (line 185) | def weight(self):
function test_dual_conv3d_consistency (line 189) | def test_dual_conv3d_consistency():
FILE: ltx_video/models/autoencoders/latent_upsampler.py
class ResBlock (line 15) | class ResBlock(nn.Module):
method __init__ (line 16) | def __init__(
method forward (line 31) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class LatentUpsampler (line 42) | class LatentUpsampler(ModelMixin, ConfigMixin):
method __init__ (line 55) | def __init__(
method forward (line 109) | def forward(self, latent: torch.Tensor) -> torch.Tensor:
method from_config (line 152) | def from_config(cls, config):
method config (line 162) | def config(self):
method from_pretrained (line 174) | def from_pretrained(
FILE: ltx_video/models/autoencoders/pixel_norm.py
class PixelNorm (line 5) | class PixelNorm(nn.Module):
method __init__ (line 6) | def __init__(self, dim=1, eps=1e-8):
method forward (line 11) | def forward(self, x):
FILE: ltx_video/models/autoencoders/pixel_shuffle.py
class PixelShuffleND (line 5) | class PixelShuffleND(nn.Module):
method __init__ (line 6) | def __init__(self, dims, upscale_factors=(2, 2, 2)):
method forward (line 12) | def forward(self, x):
FILE: ltx_video/models/autoencoders/vae.py
class AutoencoderKLWrapper (line 16) | class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
method __init__ (line 31) | def __init__(
method set_tiling_params (line 79) | def set_tiling_params(self, sample_size: int = 512, overlap_factor: fl...
method enable_z_tiling (line 85) | def enable_z_tiling(self, z_sample_size: int = 8):
method disable_z_tiling (line 98) | def disable_z_tiling(self):
method enable_hw_tiling (line 105) | def enable_hw_tiling(self):
method disable_hw_tiling (line 111) | def disable_hw_tiling(self):
method _hw_tiled_encode (line 117) | def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = T...
method blend_z (line 154) | def blend_z(
method blend_v (line 164) | def blend_v(
method blend_h (line 174) | def blend_h(
method _hw_tiled_decode (line 184) | def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
method encode (line 226) | def encode(
method _normalize_latent_channels (line 261) | def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.Fl...
method _unnormalize_latent_channels (line 275) | def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch....
method _encode (line 286) | def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
method _decode (line 292) | def _decode(
method decode (line 306) | def decode(
method forward (line 352) | def forward(
FILE: ltx_video/models/autoencoders/vae_encode.py
function vae_encode (line 22) | def vae_encode(
function vae_decode (line 94) | def vae_decode(
function _run_decoder (line 134) | def _run_decoder(
function get_vae_size_scale_factor (line 168) | def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
function latent_to_pixel_coords (line 190) | def latent_to_pixel_coords(
function latent_to_pixel_coords_from_factors (line 215) | def latent_to_pixel_coords_from_factors(
function normalize_latents (line 228) | def normalize_latents(
function un_normalize_latents (line 239) | def un_normalize_latents(
FILE: ltx_video/models/autoencoders/video_autoencoder.py
class VideoAutoencoder (line 22) | class VideoAutoencoder(AutoencoderKLWrapper):
method from_pretrained (line 24) | def from_pretrained(
method from_config (line 61) | def from_config(config):
method config (line 112) | def config(self):
method is_video_supported (line 135) | def is_video_supported(self):
method downscale_factor (line 142) | def downscale_factor(self):
method to_json_string (line 145) | def to_json_string(self) -> str:
method load_state_dict (line 150) | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool ...
method last_layer (line 174) | def last_layer(self):
class Encoder (line 185) | class Encoder(nn.Module):
method __init__ (line 208) | def __init__(
method downscale_factor (line 300) | def downscale_factor(self):
method forward (line 313) | def forward(
class Decoder (line 378) | class Decoder(nn.Module):
method __init__ (line 399) | def __init__(
method forward (line 479) | def forward(self, sample: torch.FloatTensor, target_shape) -> torch.Fl...
class DownEncoderBlock3D (line 517) | class DownEncoderBlock3D(nn.Module):
method __init__ (line 518) | def __init__(
method forward (line 560) | def forward(
class UNetMidBlock3D (line 573) | class UNetMidBlock3D(nn.Module):
method __init__ (line 591) | def __init__(
method forward (line 621) | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
class UpDecoderBlock3D (line 628) | class UpDecoderBlock3D(nn.Module):
method __init__ (line 629) | def __init__(
method forward (line 671) | def forward(
class ResnetBlock3D (line 682) | class ResnetBlock3D(nn.Module):
method __init__ (line 695) | def __init__(
method forward (line 746) | def forward(
class Downsample3D (line 773) | class Downsample3D(nn.Module):
method __init__ (line 774) | def __init__(
method forward (line 796) | def forward(self, x, downsample_in_time=True):
class Upsample3D (line 812) | class Upsample3D(nn.Module):
method __init__ (line 819) | def __init__(self, dims, channels, out_channels=None):
method forward (line 828) | def forward(self, x, upsample_in_time):
function patchify (line 868) | def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
function unpatchify (line 906) | def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=Fal...
function create_video_autoencoder_config (line 934) | def create_video_autoencoder_config(
function create_video_autoencoder_pathify4x4x4_config (line 958) | def create_video_autoencoder_pathify4x4x4_config(
function create_video_autoencoder_pathify4x4_config (line 979) | def create_video_autoencoder_pathify4x4_config(
function test_vae_patchify_unpatchify (line 997) | def test_vae_patchify_unpatchify():
function demo_video_autoencoder_forward_backward (line 1006) | def demo_video_autoencoder_forward_backward():
FILE: ltx_video/models/transformers/attention.py
class BasicTransformerBlock (line 38) | class BasicTransformerBlock(nn.Module):
method __init__ (line 77) | def __init__(
method set_use_tpu_flash_attention (line 184) | def set_use_tpu_flash_attention(self):
method set_chunk_feed_forward (line 193) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int =...
method forward (line 198) | def forward(
class Attention (line 325) | class Attention(nn.Module):
method __init__ (line 378) | def __init__(
method set_use_tpu_flash_attention (line 526) | def set_use_tpu_flash_attention(self):
method set_processor (line 532) | def set_processor(self, processor: "AttnProcessor") -> None:
method get_processor (line 554) | def get_processor(
method forward (line 660) | def forward(
method batch_to_head_dim (line 720) | def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
method head_to_batch_dim (line 739) | def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) ->...
method get_attention_scores (line 771) | def get_attention_scores(
method prepare_attention_mask (line 825) | def prepare_attention_mask(
method norm_encoder_hidden_states (line 884) | def norm_encoder_hidden_states(
method apply_rotary_emb (line 918) | def apply_rotary_emb(
class AttnProcessor2_0 (line 935) | class AttnProcessor2_0:
method __init__ (line 940) | def __init__(self):
method __call__ (line 943) | def __call__(
class AttnProcessor (line 1117) | class AttnProcessor:
method __call__ (line 1122) | def __call__(
class FeedForward (line 1204) | class FeedForward(nn.Module):
method __init__ (line 1218) | def __init__(
method forward (line 1257) | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> ...
FILE: ltx_video/models/transformers/embeddings.py
function get_timestep_embedding (line 10) | def get_timestep_embedding(
function get_3d_sincos_pos_embed (line 53) | def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
function get_3d_sincos_pos_embed_from_grid (line 66) | def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
function get_1d_sincos_pos_embed_from_grid (line 79) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
class SinusoidalPositionalEmbedding (line 103) | class SinusoidalPositionalEmbedding(nn.Module):
method __init__ (line 115) | def __init__(self, embed_dim: int, max_seq_length: int = 32):
method forward (line 126) | def forward(self, x):
FILE: ltx_video/models/transformers/symmetric_patchifier.py
class Patchifier (line 10) | class Patchifier(ConfigMixin, ABC):
method __init__ (line 11) | def __init__(self, patch_size: int):
method patchify (line 16) | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
method unpatchify (line 20) | def unpatchify(
method patch_size (line 30) | def patch_size(self):
method get_latent_coords (line 33) | def get_latent_coords(
class SymmetricPatchifier (line 54) | class SymmetricPatchifier(Patchifier):
method patchify (line 55) | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
method unpatchify (line 67) | def unpatchify(
FILE: ltx_video/models/transformers/transformer3d.py
class Transformer3DModelOutput (line 35) | class Transformer3DModelOutput(BaseOutput):
class Transformer3DModel (line 48) | class Transformer3DModel(ModelMixin, ConfigMixin):
method __init__ (line 52) | def __init__(
method set_use_tpu_flash_attention (line 162) | def set_use_tpu_flash_attention(self):
method create_skip_layer_mask (line 173) | def create_skip_layer_mask(
method _set_gradient_checkpointing (line 190) | def _set_gradient_checkpointing(self, module, value=False):
method get_fractional_positions (line 194) | def get_fractional_positions(self, indices_grid):
method precompute_freqs_cis (line 204) | def precompute_freqs_cis(self, indices_grid, spacing="exp"):
method load_state_dict (line 259) | def load_state_dict(
method from_pretrained (line 274) | def from_pretrained(
method forward (line 330) | def forward(
FILE: ltx_video/pipelines/crf_compressor.py
function _encode_single_frame (line 7) | def _encode_single_frame(output_file, image_array: np.ndarray, crf):
function _decode_single_frame (line 24) | def _decode_single_frame(video_file):
function compress (line 34) | def compress(image: torch.Tensor, crf=29):
FILE: ltx_video/pipelines/pipeline_ltx_video.py
function retrieve_timesteps (line 125) | def retrieve_timesteps(
class ConditioningItem (line 194) | class ConditioningItem:
class LTXVideoPipeline (line 213) | class LTXVideoPipeline(DiffusionPipeline):
method __init__ (line 262) | def __init__(
method mask_text_embeddings (line 298) | def mask_text_embeddings(self, emb, mask):
method encode_prompt (line 307) | def encode_prompt(
method prepare_extra_step_kwargs (line 479) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 500) | def check_inputs(
method _text_preprocessing (line 586) | def _text_preprocessing(self, text):
method add_noise_to_image_conditioning_latents (line 597) | def add_noise_to_image_conditioning_latents(
method prepare_latents (line 623) | def prepare_latents(
method classify_height_width_bin (line 704) | def classify_height_width_bin(
method resize_and_crop_tensor (line 714) | def resize_and_crop_tensor(
method resize_tensor (line 740) | def resize_tensor(media_items, height, width):
method __call__ (line 754) | def __call__(
method denoising_step (line 1348) | def denoising_step(
method prepare_conditioning (line 1383) | def prepare_conditioning(
method _resize_conditioning_item (line 1590) | def _resize_conditioning_item(
method _get_latent_spatial_position (line 1605) | def _get_latent_spatial_position(
method _handle_non_first_conditioning_sequence (line 1653) | def _handle_non_first_conditioning_sequence(
method trim_conditioning_sequence (line 1728) | def trim_conditioning_sequence(
method tone_map_latents (line 1749) | def tone_map_latents(
function adain_filter_latent (line 1790) | def adain_filter_latent(
class LTXMultiScalePipeline (line 1821) | class LTXMultiScalePipeline:
method _upsample_latents (line 1822) | def _upsample_latents(
method __init__ (line 1836) | def __init__(
method __call__ (line 1843) | def __call__(
FILE: ltx_video/schedulers/rf.py
function linear_quadratic_schedule (line 25) | def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_s...
function simple_diffusion_resolution_dependent_timestep_shift (line 49) | def simple_diffusion_resolution_dependent_timestep_shift(
function time_shift (line 69) | def time_shift(mu: float, sigma: float, t: Tensor):
function get_normal_shift (line 73) | def get_normal_shift(
function strech_shifts_to_terminal (line 85) | def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1):
function sd3_resolution_dependent_timestep_shift (line 112) | def sd3_resolution_dependent_timestep_shift(
class TimestepShifter (line 152) | class TimestepShifter(ABC):
method shift_timesteps (line 154) | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor...
class RectifiedFlowSchedulerOutput (line 159) | class RectifiedFlowSchedulerOutput(BaseOutput):
class RectifiedFlowScheduler (line 176) | class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
method __init__ (line 180) | def __init__(
method get_initial_timesteps (line 201) | def get_initial_timesteps(
method shift_timesteps (line 216) | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor...
method set_timesteps (line 227) | def set_timesteps(
method from_pretrained (line 264) | def from_pretrained(pretrained_model_path: Union[str, os.PathLike]):
method scale_model_input (line 288) | def scale_model_input(
method step (line 305) | def step(
method add_noise (line 376) | def add_noise(
FILE: ltx_video/utils/diffusers_config_mapping.py
function make_hashable_key (line 1) | def make_hashable_key(dict_key):
FILE: ltx_video/utils/prompt_enhance_utils.py
function tensor_to_pil (line 47) | def tensor_to_pil(tensor):
function generate_cinematic_prompt (line 64) | def generate_cinematic_prompt(
function _get_first_frames_from_conditioning_item (line 113) | def _get_first_frames_from_conditioning_item(conditioning_item) -> List[...
function _generate_t2v_prompt (line 121) | def _generate_t2v_prompt(
function _generate_i2v_prompt (line 151) | def _generate_i2v_prompt(
function _generate_image_captions (line 188) | def _generate_image_captions(
function _generate_and_decode_prompts (line 211) | def _generate_and_decode_prompts(
FILE: ltx_video/utils/skip_layer_strategy.py
class SkipLayerStrategy (line 4) | class SkipLayerStrategy(Enum):
FILE: ltx_video/utils/torch_utils.py
function append_dims (line 5) | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
class Identity (line 17) | class Identity(nn.Module):
method __init__ (line 20) | def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused...
method forward (line 24) | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
FILE: tests/conftest.py
function pytest_make_parametrize_id (line 14) | def pytest_make_parametrize_id(config, val, argname):
function num_latent_channels (line 21) | def num_latent_channels():
function video_autoencoder (line 26) | def video_autoencoder(num_latent_channels):
function transformer_config (line 34) | def transformer_config(num_latent_channels):
function synthetic_ckpt_path (line 67) | def synthetic_ckpt_path(
FILE: tests/test_configs.py
function prompt (line 10) | def prompt():
function test_run_config (line 20) | def test_run_config(tmp_path, prompt, pipeline_config):
FILE: tests/test_inference.py
function input_image_path (line 16) | def input_image_path():
function input_video_path (line 21) | def input_video_path():
function base_inference_config (line 25) | def base_inference_config(tmp_path, pipeline_config):
function base_pipeline_config (line 44) | def base_pipeline_config(synthetic_ckpt_path):
function test_condition_modes (line 66) | def test_condition_modes(
function test_vid2vid (line 93) | def test_vid2vid(tmp_path, input_video_path, base_pipeline_config):
function test_pipeline_on_batch (line 106) | def test_pipeline_on_batch(tmp_path, base_pipeline_config):
function test_prompt_enhancement (line 161) | def test_prompt_enhancement(tmp_path, base_pipeline_config):
FILE: tests/test_scheduler.py
function init_latents_and_scheduler (line 6) | def init_latents_and_scheduler(sampler):
function test_scheduler_default_behavior (line 18) | def test_scheduler_default_behavior(sampler):
function test_scheduler_per_token (line 41) | def test_scheduler_per_token(sampler):
function test_scheduler_t_not_in_list (line 71) | def test_scheduler_t_not_in_list(sampler):
FILE: tests/test_vae.py
function test_encode_decode_shape (line 8) | def test_encode_decode_shape(video_autoencoder, num_latent_channels):
function test_temporal_causality (line 32) | def test_temporal_causality(video_autoencoder):
function test_downscale_factors (line 59) | def test_downscale_factors(
Condensed preview — 54 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (446K chars).
[
{
"path": ".gitattributes",
"chars": 225,
"preview": "*.jpg filter=lfs diff=lfs merge=lfs -text\n*.jpeg filter=lfs diff=lfs merge=lfs -text\n*.png filter=lfs diff=lfs merge=lfs"
},
{
"path": ".github/workflows/pylint.yml",
"chars": 711,
"preview": "name: Ruff\n\non: [push]\n\njobs:\n build:\n runs-on: ubuntu-latest\n strategy:\n matrix:\n python-version: [\""
},
{
"path": ".gitignore",
"chars": 3200,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": ".pre-commit-config.yaml",
"chars": 524,
"preview": "repos:\n - repo: https://github.com/astral-sh/ruff-pre-commit\n # Ruff version.\n rev: v0.2.2\n hooks:\n # Run"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 26536,
"preview": "<div align=\"center\">\n\n# LTX-Video\n\n[](htt"
},
{
"path": "configs/ltxv-13b-0.9.8-dev-fp8.yaml",
"chars": 1398,
"preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-dev-fp8.safetensors\"\ndownscale_factor: 0.6666666\nspatial_ups"
},
{
"path": "configs/ltxv-13b-0.9.8-dev.yaml",
"chars": 1330,
"preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-dev.safetensors\"\ndownscale_factor: 0.6666666\nspatial_upscale"
},
{
"path": "configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
"chars": 1147,
"preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-distilled-fp8.safetensors\"\ndownscale_factor: 0.6666666\nspati"
},
{
"path": "configs/ltxv-13b-0.9.8-distilled.yaml",
"chars": 1080,
"preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-13b-0.9.8-distilled.safetensors\"\ndownscale_factor: 0.6666666\nspatial_u"
},
{
"path": "configs/ltxv-2b-0.9.1.yaml",
"chars": 737,
"preview": "pipeline_type: base\ncheckpoint_path: \"ltx-video-2b-v0.9.1.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scale: 0"
},
{
"path": "configs/ltxv-2b-0.9.5.yaml",
"chars": 737,
"preview": "pipeline_type: base\ncheckpoint_path: \"ltx-video-2b-v0.9.5.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scale: 0"
},
{
"path": "configs/ltxv-2b-0.9.6-dev.yaml",
"chars": 741,
"preview": "pipeline_type: base\ncheckpoint_path: \"ltxv-2b-0.9.6-dev-04-25.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scal"
},
{
"path": "configs/ltxv-2b-0.9.6-distilled.yaml",
"chars": 721,
"preview": "pipeline_type: base\ncheckpoint_path: \"ltxv-2b-0.9.6-distilled-04-25.safetensors\"\nguidance_scale: 1\nstg_scale: 0\nrescalin"
},
{
"path": "configs/ltxv-2b-0.9.8-distilled-fp8.yaml",
"chars": 1112,
"preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-2b-0.9.8-distilled-fp8.safetensors\"\ndownscale_factor: 0.6666666\nspatia"
},
{
"path": "configs/ltxv-2b-0.9.8-distilled.yaml",
"chars": 1045,
"preview": "pipeline_type: multi-scale\ncheckpoint_path: \"ltxv-2b-0.9.8-distilled.safetensors\"\ndownscale_factor: 0.6666666\nspatial_up"
},
{
"path": "configs/ltxv-2b-0.9.yaml",
"chars": 735,
"preview": "pipeline_type: base\ncheckpoint_path: \"ltx-video-2b-v0.9.safetensors\"\nguidance_scale: 3\nstg_scale: 1\nrescaling_scale: 0.7"
},
{
"path": "inference.py",
"chars": 277,
"preview": "from transformers import HfArgumentParser\n\nfrom ltx_video.inference import infer, InferenceConfig\n\n\ndef main():\n pars"
},
{
"path": "ltx_video/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/inference.py",
"chars": 26715,
"preview": "import os\nimport random\nfrom datetime import datetime\nfrom pathlib import Path\nfrom diffusers.utils import logging\nfrom "
},
{
"path": "ltx_video/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/models/autoencoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/models/autoencoders/causal_conv3d.py",
"chars": 1759,
"preview": "from typing import Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\n\nclass CausalConv3d(nn.Module):\n def __init__(\n "
},
{
"path": "ltx_video/models/autoencoders/causal_video_autoencoder.py",
"chars": 52253,
"preview": "import json\nimport os\nfrom functools import partial\nfrom types import SimpleNamespace\nfrom typing import Any, Mapping, O"
},
{
"path": "ltx_video/models/autoencoders/conv_nd_factory.py",
"chars": 2609,
"preview": "from typing import Tuple, Union\n\nimport torch\n\nfrom ltx_video.models.autoencoders.dual_conv3d import DualConv3d\nfrom ltx"
},
{
"path": "ltx_video/models/autoencoders/dual_conv3d.py",
"chars": 6885,
"preview": "import math\nfrom typing import Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ein"
},
{
"path": "ltx_video/models/autoencoders/latent_upsampler.py",
"chars": 7025,
"preview": "from typing import Optional, Union\nfrom pathlib import Path\nimport os\nimport json\n\nimport torch\nimport torch.nn as nn\nfr"
},
{
"path": "ltx_video/models/autoencoders/pixel_norm.py",
"chars": 307,
"preview": "import torch\nfrom torch import nn\n\n\nclass PixelNorm(nn.Module):\n def __init__(self, dim=1, eps=1e-8):\n super(P"
},
{
"path": "ltx_video/models/autoencoders/pixel_shuffle.py",
"chars": 1043,
"preview": "import torch.nn as nn\nfrom einops import rearrange\n\n\nclass PixelShuffleND(nn.Module):\n def __init__(self, dims, upsca"
},
{
"path": "ltx_video/models/autoencoders/vae.py",
"chars": 14404,
"preview": "from typing import Optional, Union\n\nimport torch\nimport inspect\nimport math\nimport torch.nn as nn\nfrom diffusers import "
},
{
"path": "ltx_video/models/autoencoders/vae_encode.py",
"chars": 8779,
"preview": "from typing import Tuple\nimport torch\nfrom diffusers import AutoencoderKL\nfrom einops import rearrange\nfrom torch import"
},
{
"path": "ltx_video/models/autoencoders/video_autoencoder.py",
"chars": 35462,
"preview": "import json\nimport os\nfrom functools import partial\nfrom types import SimpleNamespace\nfrom typing import Any, Mapping, O"
},
{
"path": "ltx_video/models/transformers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/models/transformers/attention.py",
"chars": 52315,
"preview": "import inspect\nfrom importlib import import_module\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch\nimport to"
},
{
"path": "ltx_video/models/transformers/embeddings.py",
"chars": 4471,
"preview": "# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py\nimport math\n\nimpor"
},
{
"path": "ltx_video/models/transformers/symmetric_patchifier.py",
"chars": 2774,
"preview": "from abc import ABC, abstractmethod\nfrom typing import Tuple\n\nimport torch\nfrom diffusers.configuration_utils import Con"
},
{
"path": "ltx_video/models/transformers/transformer3d.py",
"chars": 22744,
"preview": "# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.p"
},
{
"path": "ltx_video/pipelines/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/pipelines/crf_compressor.py",
"chars": 1533,
"preview": "import av\nimport torch\nimport io\nimport numpy as np\n\n\ndef _encode_single_frame(output_file, image_array: np.ndarray, crf"
},
{
"path": "ltx_video/pipelines/pipeline_ltx_video.py",
"chars": 83428,
"preview": "# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_"
},
{
"path": "ltx_video/schedulers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/schedulers/rf.py",
"chars": 15283,
"preview": "import math\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Callable, Optional,"
},
{
"path": "ltx_video/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ltx_video/utils/diffusers_config_mapping.py",
"chars": 5578,
"preview": "def make_hashable_key(dict_key):\n def convert_value(value):\n if isinstance(value, list):\n return tu"
},
{
"path": "ltx_video/utils/prompt_enhance_utils.py",
"chars": 7651,
"preview": "import logging\nfrom typing import Union, List, Optional\n\nimport torch\nfrom PIL import Image\n\nlogger = logging.getLogger("
},
{
"path": "ltx_video/utils/skip_layer_strategy.py",
"chars": 169,
"preview": "from enum import Enum, auto\n\n\nclass SkipLayerStrategy(Enum):\n AttentionSkip = auto()\n AttentionValues = auto()\n "
},
{
"path": "ltx_video/utils/torch_utils.py",
"chars": 822,
"preview": "import torch\nfrom torch import nn\n\n\ndef append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:\n \"\"\"Appends d"
},
{
"path": "pyproject.toml",
"chars": 849,
"preview": "[build-system]\nrequires = [\"setuptools>=42\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"ltx-vid"
},
{
"path": "tests/conftest.py",
"chars": 3050,
"preview": "import json\nimport pytest\nimport safetensors.torch\nimport torch\n\nfrom ltx_video.models.autoencoders.causal_video_autoenc"
},
{
"path": "tests/test_configs.py",
"chars": 958,
"preview": "import pytest\nfrom pathlib import Path\n\nfrom ltx_video.inference import infer, InferenceConfig\n\nCONFIGS_DIR = Path(__fil"
},
{
"path": "tests/test_inference.py",
"chars": 8780,
"preview": "from dataclasses import asdict\nimport pytest\nimport torch\nimport yaml\n\nfrom ltx_video.inference import (\n create_ltx_"
},
{
"path": "tests/test_scheduler.py",
"chars": 3688,
"preview": "import pytest\nimport torch\nfrom ltx_video.schedulers.rf import RectifiedFlowScheduler\n\n\ndef init_latents_and_scheduler(s"
},
{
"path": "tests/test_vae.py",
"chars": 3173,
"preview": "import pytest\nimport torch\nfrom ltx_video.models.autoencoders.causal_video_autoencoder import (\n CausalVideoAutoencod"
},
{
"path": "tests/utils/.gitattributes",
"chars": 42,
"preview": "*.mp4 filter=lfs diff=lfs merge=lfs -text\n"
}
]
About this extraction
This page contains the full source code of the Lightricks/LTX-Video GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 54 files (418.1 KB), approximately 100.7k tokens, and a symbol index with 287 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.