Repository: Stability-AI/generative-models Branch: main Commit: e8cd657656fa Files: 120 Total size: 778.4 KB Directory structure: gitextract_3mk8n7n3/ ├── .github/ │ └── workflows/ │ ├── black.yml │ ├── test-build.yaml │ └── test-inference.yml ├── .gitignore ├── CODEOWNERS ├── LICENSE-CODE ├── README.md ├── configs/ │ ├── example_training/ │ │ ├── autoencoder/ │ │ │ └── kl-f4/ │ │ │ ├── imagenet-attnfree-logvar.yaml │ │ │ └── imagenet-kl_f8_8chn.yaml │ │ ├── imagenet-f8_cond.yaml │ │ ├── toy/ │ │ │ ├── cifar10_cond.yaml │ │ │ ├── mnist.yaml │ │ │ ├── mnist_cond.yaml │ │ │ ├── mnist_cond_discrete_eps.yaml │ │ │ ├── mnist_cond_l1_loss.yaml │ │ │ └── mnist_cond_with_ema.yaml │ │ ├── txt2img-clipl-legacy-ucg-training.yaml │ │ └── txt2img-clipl.yaml │ └── inference/ │ ├── sd_xl_base.yaml │ ├── sd_xl_refiner.yaml │ ├── sv3d_p.yaml │ ├── sv3d_u.yaml │ ├── svd.yaml │ └── svd_image_decoder.yaml ├── main.py ├── model_licenses/ │ ├── LICENSE-SDXL-Turbo │ ├── LICENSE-SDXL0.9 │ ├── LICENSE-SDXL1.0 │ ├── LICENSE-SV3D │ └── LICENSE-SVD ├── pyproject.toml ├── pytest.ini ├── requirements/ │ └── pt2.txt ├── scripts/ │ ├── __init__.py │ ├── demo/ │ │ ├── __init__.py │ │ ├── detect.py │ │ ├── discretization.py │ │ ├── gradio_app.py │ │ ├── gradio_app_sv4d.py │ │ ├── sampling.py │ │ ├── streamlit_helpers.py │ │ ├── sv3d_helpers.py │ │ ├── sv4d_helpers.py │ │ ├── turbo.py │ │ └── video_sampling.py │ ├── sampling/ │ │ ├── configs/ │ │ │ ├── sv3d_p.yaml │ │ │ ├── sv3d_u.yaml │ │ │ ├── sv4d.yaml │ │ │ ├── sv4d2.yaml │ │ │ ├── sv4d2_8views.yaml │ │ │ ├── svd.yaml │ │ │ ├── svd_image_decoder.yaml │ │ │ ├── svd_xt.yaml │ │ │ ├── svd_xt_1_1.yaml │ │ │ └── svd_xt_image_decoder.yaml │ │ ├── simple_video_sample.py │ │ ├── simple_video_sample_4d.py │ │ └── simple_video_sample_4d2.py │ ├── tests/ │ │ └── attention.py │ └── util/ │ ├── __init__.py │ └── detection/ │ ├── __init__.py │ ├── nsfw_and_watermark_dectection.py │ ├── p_head_v1.npz │ └── w_head_v1.npz ├── sgm/ │ ├── __init__.py │ ├── data/ │ 
│ ├── __init__.py │ │ ├── cifar10.py │ │ ├── dataset.py │ │ └── mnist.py │ ├── inference/ │ │ ├── api.py │ │ └── helpers.py │ ├── lr_scheduler.py │ ├── models/ │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ └── diffusion.py │ ├── modules/ │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoding/ │ │ │ ├── __init__.py │ │ │ ├── losses/ │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator_loss.py │ │ │ │ └── lpips.py │ │ │ ├── lpips/ │ │ │ │ ├── __init__.py │ │ │ │ ├── loss/ │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── lpips.py │ │ │ │ ├── model/ │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── model.py │ │ │ │ ├── util.py │ │ │ │ └── vqperceptual.py │ │ │ ├── regularizers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── quantize.py │ │ │ └── temporal_ae.py │ │ ├── diffusionmodules/ │ │ │ ├── __init__.py │ │ │ ├── denoiser.py │ │ │ ├── denoiser_scaling.py │ │ │ ├── denoiser_weighting.py │ │ │ ├── discretizer.py │ │ │ ├── guiders.py │ │ │ ├── loss.py │ │ │ ├── loss_weighting.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── sampling.py │ │ │ ├── sampling_utils.py │ │ │ ├── sigma_sampling.py │ │ │ ├── util.py │ │ │ ├── video_model.py │ │ │ └── wrappers.py │ │ ├── distributions/ │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders/ │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── spacetime_attention.py │ │ └── video_attention.py │ └── util.py └── tests/ └── inference/ └── test_inference.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/black.yml ================================================ name: Run black on: [pull_request] jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Install venv run: | sudo apt-get -y install python3.10-venv - uses: psf/black@stable with: options: "--check --verbose -l88" src: 
"./sgm ./scripts ./main.py" ================================================ FILE: .github/workflows/test-build.yaml ================================================ name: Build package on: push: branches: [ main ] pull_request: jobs: build: name: Build runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.8", "3.10"] requirements-file: ["pt2", "pt13"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements/${{ matrix.requirements-file }}.txt pip install . ================================================ FILE: .github/workflows/test-inference.yml ================================================ name: Test inference on: pull_request: push: branches: - main jobs: test: name: "Test inference" # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment if: github.repository == 'stability-ai/generative-models' runs-on: [self-hosted, slurm, g40] steps: - uses: actions/checkout@v3 - name: "Symlink checkpoints" run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints - name: "Setup python" uses: actions/setup-python@v4 with: python-version: "3.10" - name: "Install Hatch" run: pip install hatch - name: "Run inference tests" run: hatch run ci:test-inference --junit-xml test-results.xml - name: Surface failing tests if: always() uses: pmeier/pytest-results-action@main with: path: test-results.xml summary: true display-options: fEX fail-on-empty: true ================================================ FILE: .gitignore ================================================ # extensions *.egg-info *.py[cod] # envs .pt13 .pt2 # directories /checkpoints /dist /outputs /build /src /.vscode **/__pycache__/ ================================================ FILE: CODEOWNERS 
================================================ .github @Stability-AI/infrastructure ================================================ FILE: LICENSE-CODE ================================================ MIT License Copyright (c) 2023 Stability AI Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Generative Models by Stability AI ![sample1](assets/000.jpg) ## News **May 20, 2025** - We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes: - **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object. 
- Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on reference multi-views of the first frame generated by SV3D, making it more robust to self-occlusions. - To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames. - Please check our [project page](https://sv4d20.github.io), [arxiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details. **QUICKSTART** : - `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from HuggingFace into `checkpoints/`) To run **SV4D 2.0** on a single input video of 21 frames: - Download SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints` - Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>` - `input_path` : The input video `<path/to/video>` can be - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or - a file name pattern matching images of video frames. - `num_steps` : default is 50, can decrease it to shorten sampling time. - `elevations_deg` : specified elevations (relative to input view), default is 0.0 (same as input view).
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D. - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or lower video resolution like `--img_size=512`. Notes: - We also train an 8-view model that generates 5 frames x 8 views at a time (same as SV4D). - Download the model from huggingface: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints` - Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs` - The 5x8 model takes 5 frames of input at a time. However, the inference scripts for both models take a 21-frame video as input by default (same as SV3D and SV4D), so we run the model autoregressively until we generate 21 frames. - Install dependencies before running: ``` python3.10 -m venv .generativemodels source .generativemodels/bin/activate pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version pip3 install -r requirements/pt2.txt pip3 install . pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata ``` ![tile](assets/sv4d2.gif) **July 24, 2024** - We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis.
For research purposes: - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object. - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency. - To run the community-built gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`. - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details. **QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`) To run **SV4D** on a single input video of 21 frames: - Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/` - Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>` - `input_path` : The input video `<path/to/video>` can be - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or - a file name pattern matching images of video frames. - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p. - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0` - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D. - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or lower video resolution like `--img_size=512`. ![tile](assets/sv4d.gif) **March 18, 2024** - We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes: - **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object. - **SV3D_u**: This variant generates orbital videos based on single image inputs without camera conditioning. - **SV3D_p**: Extending the capability of **SV3D_u**, this variant accommodates both single images and orbital views allowing for the creation of 3D video along specified camera paths. - We extend the streamlit demo `scripts/demo/video_sampling.py` and the standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
- Please check our [project page](https://sv3d.github.io), [tech report](https://sv3d.github.io/static/paper.pdf) and [video summary](https://youtu.be/Zqw4-1LcfWg) for more details. To run **SV3D_u** on a single image: - Download `sv3d_u.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_u.safetensors` - Run `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_u` To run **SV3D_p** on a single image: - Download `sv3d_p.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_p.safetensors` 1. Generate static orbit at a specified elevation, e.g. 10.0 : `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg 10.0` 2. Generate dynamic orbit at specified elevations and azimuths: specify sequences of 21 elevations (in degrees) to `elevations_deg` ([-90, 90]), and 21 azimuths (in degrees) to `azimuths_deg` [0, 360] in sorted order from 0 to 360. For example: `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg [<list of 21 elevations in degrees>] --azimuths_deg [<list of 21 azimuths in degrees>]` To run SVD or SV3D on a streamlit server: `streamlit run scripts/demo/video_sampling.py` ![tile](assets/sv3d.gif) **November 28, 2023** - We are releasing SDXL-Turbo, a lightning fast text-to-image model. Alongside the model, we release a [technical report](https://stability.ai/research/adversarial-diffusion-distillation). - Usage: - Follow the installation instructions or update the existing environment with `pip install streamlit-keyup`. - Download the [weights](https://huggingface.co/stabilityai/sdxl-turbo) and place them in the `checkpoints/` directory. - Run `streamlit run scripts/demo/turbo.py`.
![tile](assets/turbo_tile.png) **November 21, 2023** - We are releasing Stable Video Diffusion, an image-to-video model, for research purposes: - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14 frames at resolution 576x1024 given a context frame of the same size. We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`. - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned for 25 frame generation. - You can run the community-built gradio demo locally by running `python -m scripts.demo.gradio_app`. - We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models. - Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets). ![tile](assets/tile.gif) **July 26, 2023** - We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes): - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`. - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`. ![sample2](assets/001_with_eval.png) **July 4, 2023** - A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952). **June 22, 2023** - We are releasing two new diffusion models for research purposes: - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2.
The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model. - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model. If you would like to access these models for your research, please apply using one of the following links: [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9). This means that you can apply for any of the two links - and if you are granted - you can access both. Please log in to your Hugging Face Account with your organization email to request access. **We plan to do a full release soon (July).** ## The codebase ### General Philosophy Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples. ### Changelog from the old `ldm` codebase For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up: - No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`. - We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model. 
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models): * Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`. * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`). - Autoencoding models have also been cleaned up. ## Installation: #### 1. Clone the repo ```shell git clone https://github.com/Stability-AI/generative-models.git cd generative-models ``` #### 2. Setting up the virtualenv This is assuming you have navigated to the `generative-models` root after cloning it. **NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts. **PyTorch 2.0** ```shell # install required packages from pypi python3 -m venv .pt2 source .pt2/bin/activate pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip3 install -r requirements/pt2.txt ``` #### 3. Install `sgm` ```shell pip3 install . ``` #### 4. Install `sdata` for training ```shell pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata ``` ## Packaging This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/). To build a distributable wheel, install `hatch` and run `hatch build` (specifying `-t wheel` will skip building a sdist, which is not necessary). ``` pip install hatch hatch build -t wheel ``` You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`. 
Note that the package does **not** currently specify dependencies; you will need to install the required packages, depending on your use case and PyTorch version, manually. ## Inference We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. We provide file hashes for the complete file as well as for only the saved tensors in the file ( see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that). The following models are currently supported: - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) ``` File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7 ``` - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) ``` File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81 ``` - [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) - [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9) **Weights for SDXL**: **SDXL-1.0:** The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here: - base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/ - refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/ **SDXL-0.9:** The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9). 
If you would like to access these models for your research, please apply using one of the following links: [SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9). This means that you can apply for any of the two links - and if you are granted - you can access both. Please log in to your Hugging Face Account with your organization email to request access. After obtaining the weights, place them into `checkpoints/`. Next, start the demo using ``` streamlit run scripts/demo/sampling.py --server.port ``` ### Invisible Watermark Detection Images generated with our code use the [invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/) library to embed an invisible watermark into the model output. We also provide a script to easily detect that watermark. Please note that this watermark is not the same as in previous Stable Diffusion 1.x/2.x versions. To run the script you need to either have a working installation as above or try an _experimental_ import using only a minimal amount of packages: ```bash python -m venv .detect source .detect/bin/activate pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25" pip install --no-deps invisible-watermark ``` To run the script you need to have a working installation as above. The script is then useable in the following ways (don't forget to activate your virtual environment beforehand, e.g. `source .pt1/bin/activate`): ```bash # test a single file python scripts/demo/detect.py # test multiple files at once python scripts/demo/detect.py ... # test all files in a specific folder python scripts/demo/detect.py /* ``` ## Training: We are providing example training configs in `configs/example_training`. To launch a training, run ``` python main.py --base configs/ configs/ ``` where configs are merged from left to right (later configs overwrite the same values). 
This can be used to combine model, training and data configs. However, all of them can also be defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST, run ```bash python main.py --base configs/example_training/toy/mnist_cond.yaml ``` **NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to be stored in tar-files in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config. **NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported. **NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs. ### Building New Diffusion Models #### Conditioner The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model. All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
When computing conditionings, the embedder will get `batch[input_key]` as input. We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated appropriately. Note that the order of the embedders in the `conditioner_config` is important. #### Network The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general enough as we plan to experiment with transformer-based diffusion backbones. #### Loss The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`. #### Sampler config As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical solver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free guidance. ### Dataset Handling For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation). Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of data keys/values, e.g., ```python example = {"jpg": x, # this is a tensor -1...1 chw "txt": "a beautiful image"} ``` where we expect images in -1...1, channel-first format. 
================================================ FILE: configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml ================================================ model: base_learning_rate: 4.5e-6 target: sgm.models.autoencoder.AutoencodingEngine params: input_key: jpg monitor: val/rec_loss loss_config: target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator params: perceptual_weight: 0.25 disc_start: 20001 disc_weight: 0.5 learn_logvar: True regularization_weights: kl_loss: 1.0 regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: none double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4] num_res_blocks: 4 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: ${model.params.encoder_config.params} data: target: sgm.data.dataset.StableDataModuleFromConfig params: train: datapipeline: urls: - DATA-PATH pipeline_config: shardshuffle: 10000 sample_shuffle: 10000 decoders: - pil postprocessors: - target: sdata.mappers.TorchVisionImageTransforms params: key: jpg transforms: - target: torchvision.transforms.Resize params: size: 256 interpolation: 3 - target: torchvision.transforms.ToTensor - target: sdata.mappers.Rescaler - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare params: h_key: height w_key: width loader: batch_size: 8 num_workers: 4 lightning: strategy: target: pytorch_lightning.strategies.DDPStrategy params: find_unused_parameters: True modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 50000 image_logger: target: main.ImageLogger params: enable_autocast: False batch_frequency: 1000 max_images: 8 increase_log_steps: True trainer: devices: 0, limit_val_batches: 50 benchmark: True accumulate_grad_batches: 1 
val_check_interval: 10000 ================================================ FILE: configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml ================================================ model: base_learning_rate: 4.5e-6 target: sgm.models.autoencoder.AutoencodingEngine params: input_key: jpg monitor: val/loss/rec disc_start_iter: 0 encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: vanilla-xformers double_z: true z_channels: 8 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: ${model.params.encoder_config.params} regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer loss_config: target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator params: perceptual_weight: 0.25 disc_start: 20001 disc_weight: 0.5 learn_logvar: True regularization_weights: kl_loss: 1.0 data: target: sgm.data.dataset.StableDataModuleFromConfig params: train: datapipeline: urls: - DATA-PATH pipeline_config: shardshuffle: 10000 sample_shuffle: 10000 decoders: - pil postprocessors: - target: sdata.mappers.TorchVisionImageTransforms params: key: jpg transforms: - target: torchvision.transforms.Resize params: size: 256 interpolation: 3 - target: torchvision.transforms.ToTensor - target: sdata.mappers.Rescaler - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare params: h_key: height w_key: width loader: batch_size: 8 num_workers: 4 lightning: strategy: target: pytorch_lightning.strategies.DDPStrategy params: find_unused_parameters: True modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 50000 image_logger: target: main.ImageLogger params: enable_autocast: False batch_frequency: 1000 max_images: 8 increase_log_steps: True trainer: devices: 0, limit_val_batches: 50 
benchmark: True accumulate_grad_batches: 1 val_check_interval: 10000 ================================================ FILE: configs/example_training/imagenet-f8_cond.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.13025 disable_first_stage_autocast: True log_keys: - cls scheduler_config: target: sgm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [10000] cycle_lengths: [10000000000000] f_start: [1.e-6] f_max: [1.] f_min: [1.] denoiser_config: target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser params: num_idx: 1000 scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True in_channels: 4 out_channels: 4 model_channels: 256 attention_resolutions: [1, 2, 4] num_res_blocks: 2 channel_mult: [1, 2, 4] num_head_channels: 64 num_classes: sequential adm_in_channels: 1024 transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: cls ucg_rate: 0.2 target: sgm.modules.encoders.modules.ClassEmbedder params: add_sequence_dim: True embed_dim: 1024 n_classes: 1000 - is_trainable: False ucg_rate: 0.2 input_key: original_size_as_tuple target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: crop_coords_top_left ucg_rate: 0.2 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: ckpt_path: CKPT_PATH embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 
128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling params: num_idx: 1000 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 5.0 data: target: sgm.data.dataset.StableDataModuleFromConfig params: train: datapipeline: urls: # USER: adapt this path the root of your custom dataset - DATA_PATH pipeline_config: shardshuffle: 10000 sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM decoders: - pil postprocessors: - target: sdata.mappers.TorchVisionImageTransforms params: key: jpg # USER: you might wanna adapt this for your custom dataset transforms: - target: torchvision.transforms.Resize params: size: 256 interpolation: 3 - target: torchvision.transforms.ToTensor - target: sdata.mappers.Rescaler - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare params: h_key: height # USER: you might wanna adapt this for your custom dataset w_key: width # USER: you might wanna adapt this for your custom dataset loader: batch_size: 64 num_workers: 6 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False enable_autocast: False batch_frequency: 1000 max_images: 8 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: 
False N: 8 n_rows: 2 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 1000 ================================================ FILE: configs/example_training/toy/cifar10_cond.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling params: sigma_data: 1.0 network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: in_channels: 3 out_channels: 3 model_channels: 32 attention_resolutions: [] num_res_blocks: 4 channel_mult: [1, 2, 2] num_head_channels: 32 num_classes: sequential adm_in_channels: 128 conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: cls ucg_rate: 0.2 target: sgm.modules.encoders.modules.ClassEmbedder params: embed_dim: 128 n_classes: 10 first_stage_config: target: sgm.models.autoencoder.IdentityFirstStage loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting params: sigma_data: 1.0 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 3.0 data: target: sgm.data.cifar10.CIFAR10Loader params: batch_size: 512 num_workers: 1 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False batch_frequency: 1000 max_images: 64 
increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 64 n_rows: 8 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 20 ================================================ FILE: configs/example_training/toy/mnist.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling params: sigma_data: 1.0 network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: in_channels: 1 out_channels: 1 model_channels: 32 attention_resolutions: [] num_res_blocks: 4 channel_mult: [1, 2, 2] num_head_channels: 32 first_stage_config: target: sgm.models.autoencoder.IdentityFirstStage loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting params: sigma_data: 1.0 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization data: target: sgm.data.mnist.MNISTLoader params: batch_size: 512 num_workers: 1 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False batch_frequency: 1000 max_images: 64 increase_log_steps: False log_first_step: False log_images_kwargs: use_ema_scope: False N: 64 n_rows: 8 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 10 ================================================ FILE: 
configs/example_training/toy/mnist_cond.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling params: sigma_data: 1.0 network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: in_channels: 1 out_channels: 1 model_channels: 32 attention_resolutions: [] num_res_blocks: 4 channel_mult: [1, 2, 2] num_head_channels: 32 num_classes: sequential adm_in_channels: 128 conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: cls ucg_rate: 0.2 target: sgm.modules.encoders.modules.ClassEmbedder params: embed_dim: 128 n_classes: 10 first_stage_config: target: sgm.models.autoencoder.IdentityFirstStage loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting params: sigma_data: 1.0 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 3.0 data: target: sgm.data.mnist.MNISTLoader params: batch_size: 512 num_workers: 1 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False batch_frequency: 1000 max_images: 16 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 16 n_rows: 4 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 
1 max_epochs: 20 ================================================ FILE: configs/example_training/toy/mnist_cond_discrete_eps.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: denoiser_config: target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser params: num_idx: 1000 scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: in_channels: 1 out_channels: 1 model_channels: 32 attention_resolutions: [] num_res_blocks: 4 channel_mult: [1, 2, 2] num_head_channels: 32 num_classes: sequential adm_in_channels: 128 conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: cls ucg_rate: 0.2 target: sgm.modules.encoders.modules.ClassEmbedder params: embed_dim: 128 n_classes: 10 first_stage_config: target: sgm.models.autoencoder.IdentityFirstStage loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling params: num_idx: 1000 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 5.0 data: target: sgm.data.mnist.MNISTLoader params: batch_size: 512 num_workers: 1 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 
image_logger: target: main.ImageLogger params: disabled: False batch_frequency: 1000 max_images: 16 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 16 n_rows: 4 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 20 ================================================ FILE: configs/example_training/toy/mnist_cond_l1_loss.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling params: sigma_data: 1.0 network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: in_channels: 1 out_channels: 1 model_channels: 32 attention_resolutions: [] num_res_blocks: 4 channel_mult: [1, 2, 2] num_head_channels: 32 num_classes: sequential adm_in_channels: 128 conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: cls ucg_rate: 0.2 target: sgm.modules.encoders.modules.ClassEmbedder params: embed_dim: 128 n_classes: 10 first_stage_config: target: sgm.models.autoencoder.IdentityFirstStage loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_type: l1 loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting params: sigma_data: 1.0 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 3.0 data: target: sgm.data.mnist.MNISTLoader params: batch_size: 512 num_workers: 1 lightning: modelcheckpoint: params: 
every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False batch_frequency: 1000 max_images: 64 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 64 n_rows: 8 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 20 ================================================ FILE: configs/example_training/toy/mnist_cond_with_ema.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: use_ema: True denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling params: sigma_data: 1.0 network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: in_channels: 1 out_channels: 1 model_channels: 32 attention_resolutions: [] num_res_blocks: 4 channel_mult: [1, 2, 2] num_head_channels: 32 num_classes: sequential adm_in_channels: 128 conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: cls ucg_rate: 0.2 target: sgm.modules.encoders.modules.ClassEmbedder params: embed_dim: 128 n_classes: 10 first_stage_config: target: sgm.models.autoencoder.IdentityFirstStage loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting params: sigma_data: 1.0 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 3.0 data: target: 
sgm.data.mnist.MNISTLoader params: batch_size: 512 num_workers: 1 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False batch_frequency: 1000 max_images: 64 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 64 n_rows: 8 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 20 ================================================ FILE: configs/example_training/txt2img-clipl-legacy-ucg-training.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.13025 disable_first_stage_autocast: True log_keys: - txt scheduler_config: target: sgm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [10000] cycle_lengths: [10000000000000] f_start: [1.e-6] f_max: [1.] f_min: [1.] 
denoiser_config: target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser params: num_idx: 1000 scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [1, 2, 4] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 num_classes: sequential adm_in_channels: 1792 num_heads: 1 transformer_depth: 1 context_dim: 768 spatial_transformer_attn_type: softmax-xformers conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: txt ucg_rate: 0.1 legacy_ucg_value: "" target: sgm.modules.encoders.modules.FrozenCLIPEmbedder params: always_return_pooled: True - is_trainable: False ucg_rate: 0.1 input_key: original_size_as_tuple target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: crop_coords_top_left ucg_rate: 0.1 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: ckpt_path: CKPT_PATH embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling params: num_idx: 1000 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization sampler_config: target: 
sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 7.5 data: target: sgm.data.dataset.StableDataModuleFromConfig params: train: datapipeline: urls: # USER: adapt this path the root of your custom dataset - DATA_PATH pipeline_config: shardshuffle: 10000 sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM decoders: - pil postprocessors: - target: sdata.mappers.TorchVisionImageTransforms params: key: jpg # USER: you might wanna adapt this for your custom dataset transforms: - target: torchvision.transforms.Resize params: size: 256 interpolation: 3 - target: torchvision.transforms.ToTensor - target: sdata.mappers.Rescaler - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare # USER: you might wanna use non-default parameters due to your custom dataset loader: batch_size: 64 num_workers: 6 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False enable_autocast: False batch_frequency: 1000 max_images: 8 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 8 n_rows: 2 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 1000 ================================================ FILE: configs/example_training/txt2img-clipl.yaml ================================================ model: base_learning_rate: 1.0e-4 target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.13025 disable_first_stage_autocast: True log_keys: - txt scheduler_config: target: sgm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [10000] cycle_lengths: [10000000000000] f_start: [1.e-6] f_max: [1.] f_min: [1.] 
denoiser_config: target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser params: num_idx: 1000 scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [1, 2, 4] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 num_classes: sequential adm_in_channels: 1792 num_heads: 1 transformer_depth: 1 context_dim: 768 spatial_transformer_attn_type: softmax-xformers conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: True input_key: txt ucg_rate: 0.1 legacy_ucg_value: "" target: sgm.modules.encoders.modules.FrozenCLIPEmbedder params: always_return_pooled: True - is_trainable: False ucg_rate: 0.1 input_key: original_size_as_tuple target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: crop_coords_top_left ucg_rate: 0.1 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: ckpt_path: CKPT_PATH embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity loss_fn_config: target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss params: loss_weighting_config: target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling params: num_idx: 1000 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization sampler_config: target: 
sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization guider_config: target: sgm.modules.diffusionmodules.guiders.VanillaCFG params: scale: 7.5 data: target: sgm.data.dataset.StableDataModuleFromConfig params: train: datapipeline: urls: # USER: adapt this path the root of your custom dataset - DATA_PATH pipeline_config: shardshuffle: 10000 sample_shuffle: 10000 decoders: - pil postprocessors: - target: sdata.mappers.TorchVisionImageTransforms params: key: jpg # USER: you might wanna adapt this for your custom dataset transforms: - target: torchvision.transforms.Resize params: size: 256 interpolation: 3 - target: torchvision.transforms.ToTensor - target: sdata.mappers.Rescaler # USER: you might wanna use non-default parameters due to your custom dataset - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare # USER: you might wanna use non-default parameters due to your custom dataset loader: batch_size: 64 num_workers: 6 lightning: modelcheckpoint: params: every_n_train_steps: 5000 callbacks: metrics_over_trainsteps_checkpoint: params: every_n_train_steps: 25000 image_logger: target: main.ImageLogger params: disabled: False enable_autocast: False batch_frequency: 1000 max_images: 8 increase_log_steps: True log_first_step: False log_images_kwargs: use_ema_scope: False N: 8 n_rows: 2 trainer: devices: 0, benchmark: True num_sanity_val_steps: 0 accumulate_grad_batches: 1 max_epochs: 1000 ================================================ FILE: configs/inference/sd_xl_base.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.13025 disable_first_stage_autocast: True denoiser_config: target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser params: num_idx: 1000 scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling discretization_config: 
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: adm_in_channels: 2816 num_classes: sequential use_checkpoint: True in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2] num_res_blocks: 2 channel_mult: [1, 2, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: [1, 2, 10] context_dim: 2048 spatial_transformer_attn_type: softmax-xformers conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: txt target: sgm.modules.encoders.modules.FrozenCLIPEmbedder params: layer: hidden layer_idx: 11 - is_trainable: False input_key: txt target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 params: arch: ViT-bigG-14 version: laion2b_s39b_b160k freeze: True layer: penultimate always_return_pooled: True legacy: False - is_trainable: False input_key: original_size_as_tuple target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: crop_coords_top_left target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: target_size_as_tuple target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity ================================================ FILE: configs/inference/sd_xl_refiner.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.13025 disable_first_stage_autocast: True denoiser_config: target: 
sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser params: num_idx: 1000 scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling discretization_config: target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization network_config: target: sgm.modules.diffusionmodules.openaimodel.UNetModel params: adm_in_channels: 2560 num_classes: sequential use_checkpoint: True in_channels: 4 out_channels: 4 model_channels: 384 attention_resolutions: [4, 2] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 4 context_dim: [1280, 1280, 1280, 1280] spatial_transformer_attn_type: softmax-xformers conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: txt target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 params: arch: ViT-bigG-14 version: laion2b_s39b_b160k legacy: False freeze: True layer: penultimate always_return_pooled: True - is_trainable: False input_key: original_size_as_tuple target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: crop_coords_top_left target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - is_trainable: False input_key: aesthetic_score target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity ================================================ FILE: configs/inference/sv3d_p.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True 
denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 1280 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - input_key: cond_frames_without_noise is_trainable: False target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: polars_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: azimuths_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 
first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 attn_resolutions: [ ] dropout: 0.0 ================================================ FILE: configs/inference/sv3d_u.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 256 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - input_key: cond_frames_without_noise is_trainable: False target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True 
n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 attn_resolutions: [ ] dropout: 0.0 ================================================ FILE: configs/inference/svd.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: 
sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.autoencoding.temporal_ae.VideoDecoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 video_kernel_size: [3, 1, 1] 
================================================ FILE: configs/inference/svd_image_decoder.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 
def default_trainer_args(trainer_cls=None):
    """Return the ``__init__`` keyword arguments of a trainer that have defaults.

    Args:
        trainer_cls: Class whose ``__init__`` signature is inspected. Defaults
            to the ``Trainer`` imported by this module.

    Returns:
        dict mapping parameter name to its default value. Parameters without
        a default (including ``*args`` / ``**kwargs``) are left out.
    """
    if trainer_cls is None:
        trainer_cls = Trainer
    params = dict(inspect.signature(trainer_cls.__init__).parameters)
    params.pop("self")
    # Compare the parameter's *default* with the sentinel. Comparing the
    # Parameter object itself is always unequal and filters nothing.
    return {
        name: param.default
        for name, param in params.items()
        if param.default is not Parameter.empty
    }


def get_parser(**parser_kwargs):
    """Build the command-line parser for the training script.

    Besides the script's own options, every defaulted ``Trainer.__init__``
    argument is exposed as ``--<name>`` with the trainer's default value.

    Args:
        **parser_kwargs: Forwarded to ``argparse.ArgumentParser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """

    def str2bool(v):
        """Parse common textual spellings of a boolean."""
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)

    def add_bool_flag(*names, default=False, help=None):
        """Register an optional-value boolean flag (bare flag means True)."""
        parser.add_argument(
            *names,
            type=str2bool,
            nargs="?",
            const=True,
            default=default,
            help=help,
        )

    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    add_bool_flag(
        "--no_date",
        help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    add_bool_flag("-t", "--train", default=True, help="train")
    add_bool_flag("--no-test", help="disable test")
    parser.add_argument(
        "-p", "--project", help="name of new or path to existing project"
    )
    add_bool_flag("-d", "--debug", help="enable post-mortem debugging")
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "--projectname",
        type=str,
        default="stablediffusion",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging data",
    )
    add_bool_flag(
        "--scale_lr",
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    add_bool_flag(
        "--legacy_naming",
        help="name run based on config file name if true, else by whole path",
    )
    add_bool_flag(
        "--enable_tf32",
        help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
    )
    parser.add_argument(
        "--startup",
        type=str,
        default=None,
        help="Startuptime from distributed script",
    )
    add_bool_flag("--wandb", help="log to wandb")  # TODO: later default to True
    add_bool_flag(
        "--no_base_name",
        help="do not derive the run name from the first base config",
    )  # TODO: later default to True

    default_args = default_trainer_args()
    # Whether the option must be added by hand depends on the Trainer
    # signature, not on the torch version: registering it twice makes argparse
    # raise a conflict error, and not registering it at all leaves
    # `opt.resume_from_checkpoint` undefined.
    if "resume_from_checkpoint" not in default_args:
        parser.add_argument(
            "--resume_from_checkpoint",
            type=str,
            default=None,
            help="single checkpoint file to resume from",
        )
    for key in default_args:
        parser.add_argument("--" + key, default=default_args[key])
    return parser
def get_checkpoint_name(logdir):
    """Locate the checkpoint to resume from and name the emergency checkpoint.

    Searches ``<logdir>/checkpoints`` for files matching ``last**.ckpt``. When
    several exist, the most recently modified one is chosen and its path is
    written to ``<logdir>/most_recent_ckpt.txt``.

    Args:
        logdir: Run directory that contains a ``checkpoints`` sub-directory.

    Returns:
        Tuple ``(ckpt, melk_ckpt_name)``: the chosen checkpoint path and the
        file name (``last-v<N>.ckpt``) to use for emergency checkpoints.

    Raises:
        FileNotFoundError: If no matching checkpoint exists.
    """
    ckpt_dir = os.path.join(logdir, "checkpoints")
    candidates = natsorted(glob.glob(os.path.join(ckpt_dir, "last**.ckpt")))
    print('available "last" checkpoints:')
    print(candidates)
    if not candidates:
        # Fail with an actionable message instead of an IndexError below.
        raise FileNotFoundError(f'No "last*.ckpt" checkpoint found in {ckpt_dir}')
    if len(candidates) > 1:
        print("got most recent checkpoint")
        ckpt = sorted(candidates, key=os.path.getmtime)[-1]
        print(f"Most recent ckpt is {ckpt}")
        with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
            f.write(ckpt + "\n")
        try:
            # "last-v3.ckpt" -> 3; a plain "last.ckpt" has no numeric suffix.
            ckpt_version = int(
                os.path.basename(ckpt).split("-v")[-1].split(".")[0]
            )
        except ValueError as e:
            print("version confusion but not bad")
            print(e)
            ckpt_version = 1
        # version = last_version + 1
    else:
        # in this case, we only have one "last.ckpt"
        ckpt = candidates[0]
        ckpt_version = 1
    melk_ckpt_name = f"last-v{ckpt_version}.ckpt"
    print(f"Current melk ckpt name: {melk_ckpt_name}")
    return ckpt, melk_ckpt_name
class SetupCallback(Callback):
    """Callback that prepares the run directory and saves the configs.

    On rank zero it creates the log, checkpoint and config directories and
    writes the project and lightning configs at the start of fitting. On an
    exception it writes an emergency checkpoint unless debugging is enabled.
    """

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        self.resume, self.now = resume, now
        self.logdir, self.ckptdir, self.cfgdir = logdir, ckptdir, cfgdir
        self.config, self.lightning_config = config, lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        """Save an emergency checkpoint from rank zero when not debugging."""
        if self.debug or trainer.global_rank != 0:
            return
        print("Summoning checkpoint.")
        filename = "last.ckpt" if self.ckpt_name is None else self.ckpt_name
        trainer.save_checkpoint(os.path.join(self.ckptdir, filename))

    def on_fit_start(self, trainer, pl_module):
        """Create the run directories and save both configs (rank zero only)."""
        if trainer.global_rank != 0:
            # ModelCheckpoint callback created log directory --- remove it
            if MULTINODE_HACKS or self.resume or not os.path.exists(self.logdir):
                return
            parent, run_name = os.path.split(self.logdir)
            target = os.path.join(parent, "child_runs", run_name)
            os.makedirs(os.path.split(target)[0], exist_ok=True)
            try:
                os.rename(self.logdir, target)
            except FileNotFoundError:
                pass
            return

        # Create logdirs and save configs
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        if (
            "callbacks" in self.lightning_config
            and "metrics_over_trainsteps_checkpoint"
            in self.lightning_config["callbacks"]
        ):
            os.makedirs(
                os.path.join(self.ckptdir, "trainstep_checkpoints"),
                exist_ok=True,
            )

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        if MULTINODE_HACKS:
            import time

            time.sleep(5)
        project_cfg_path = os.path.join(
            self.cfgdir, "{}-project.yaml".format(self.now)
        )
        OmegaConf.save(self.config, project_cfg_path)

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        lightning_cfg_path = os.path.join(
            self.cfgdir, "{}-lightning.yaml".format(self.now)
        )
        OmegaConf.save(
            OmegaConf.create({"lightning": self.lightning_config}),
            lightning_cfg_path,
        )
class ImageLogger(Callback):
    """Callback that periodically saves the images returned by ``log_images``.

    Images are written to ``<save_dir>/images/<split>`` and, when the module's
    logger is a ``WandbLogger``, also sent to wandb. Logging happens every
    ``batch_frequency`` steps and, unless ``increase_log_steps`` is False,
    additionally at powers of two up to ``batch_frequency``.
    """

    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_before_first_step=False,
        enable_autocast=True,
    ):
        super().__init__()
        self.enable_autocast = enable_autocast
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Early logging schedule: 1, 2, 4, ... up to batch_frequency.
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
        self.log_before_first_step = log_before_first_step

    @rank_zero_only
    def log_local(
        self,
        save_dir,
        split,
        images,
        global_step,
        current_epoch,
        batch_idx,
        pl_module: Union[None, pl.LightningModule] = None,
    ):
        """Write every entry of ``images`` to disk as a PNG.

        Heatmaps are rendered with matplotlib; everything else is arranged
        into a grid, saved with PIL and, if ``pl_module`` is given, also
        logged through its ``WandbLogger``.
        """
        root = os.path.join(save_dir, "images", split)
        for k in images:
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k, global_step, current_epoch, batch_idx
            )
            path = os.path.join(root, filename)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            if isheatmap(images[k]):
                _, ax = plt.subplots()
                mappable = ax.matshow(
                    images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
                )
                plt.colorbar(mappable)
                plt.axis("off")
                plt.savefig(path)
                plt.close()
                # TODO: support wandb
            else:
                grid = torchvision.utils.make_grid(images[k], nrow=4)
                if self.rescale:
                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
                grid = grid.numpy()
                grid = (grid * 255).astype(np.uint8)
                img = Image.fromarray(grid)
                img.save(path)
                if exists(pl_module):
                    assert isinstance(
                        pl_module.logger, WandbLogger
                    ), "logger_log_image only supports WandbLogger currently"
                    pl_module.logger.log_image(
                        key=f"{split}/{k}",
                        images=[
                            img,
                        ],
                        step=pl_module.global_step,
                    )

    @rank_zero_only
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Run ``pl_module.log_images`` and save the result if logging is due."""
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check_frequency is evaluated first because it consumes the schedule.
        if not (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            return

        is_train = pl_module.training
        if is_train:
            pl_module.eval()
        try:
            gpu_autocast_kwargs = {
                "enabled": self.enable_autocast,  # torch.is_autocast_enabled(),
                "dtype": torch.get_autocast_gpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled(),
            }
            with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                if not isheatmap(images[k]):
                    images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().float().cpu()
                    if self.clamp and not isheatmap(images[k]):
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
                pl_module=pl_module
                if isinstance(pl_module.logger, WandbLogger)
                else None,
            )
        finally:
            # Restore training mode even if image logging failed.
            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        """Return True when ``check_idx`` is a step at which to log images.

        A hit consumes the first entry of the early logging schedule, if any
        is left.
        """
        due = (check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)
        if due and (check_idx > 0 or self.log_first_step):
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log training images after a batch, once logging is allowed."""
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Optionally log images before the very first optimisation step."""
        if (
            not self.disabled
            and self.log_before_first_step
            and pl_module.global_step == 0
        ):
            print(f"{self.__class__.__name__}: logging before training")
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        """Log validation images and, if supported, gradient statistics."""
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if (
            getattr(pl_module, "calibrate_grad_norm", False)
            and batch_idx % 25 == 0
            and batch_idx > 0
        ):
            # This class defines no `log_gradients`; only call it when a
            # subclass provides one instead of raising AttributeError.
            log_gradients = getattr(self, "log_gradients", None)
            if callable(log_gradients):
                log_gradients(trainer, pl_module, batch_idx=batch_idx)
@rank_zero_only
def init_wandb(save_dir, opt, config, group_name, name_str):
    """Point wandb at ``save_dir`` and start a run (offline when debugging)."""
    print(f"setting WANDB_DIR to {save_dir}")
    os.makedirs(save_dir, exist_ok=True)
    os.environ["WANDB_DIR"] = save_dir
    if opt.debug:
        wandb.init(project=opt.projectname, mode="offline", group=group_name)
    else:
        wandb.init(
            project=opt.projectname,
            config=config,
            settings=wandb.Settings(code_dir="./sgm"),
            group=group_name,
            name=name_str,
        )


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()

    opt, unknown = parser.parse_known_args()

    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    melk_ckpt_name = None
    name = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
            _, melk_ckpt_name = get_checkpoint_name(logdir)
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)

        print("#" * 100)
        print(f'Resuming from checkpoint "{ckpt}"')
        print("#" * 100)

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            if opt.no_base_name:
                name = ""
            else:
                base_dir, base_fname = os.path.split(opt.base[0])
                if opt.legacy_naming:
                    cfg_name = os.path.splitext(base_fname)[0]
                else:
                    assert "configs" in base_dir, base_dir
                    dir_parts = base_dir.split(os.sep)
                    # cut away the first one (we assert all configs are in "configs")
                    cfg_path = dir_parts[dir_parts.index("configs") + 1 :]
                    cfg_name = os.path.splitext(base_fname)[0]
                    cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
                name = "_" + cfg_name
        else:
            name = ""
        if not opt.no_date:
            nowname = now + name + opt.postfix
        else:
            nowname = name + opt.postfix
            if nowname.startswith("_"):
                nowname = nowname[1:]
        logdir = os.path.join(opt.logdir, nowname)
        print(f"LOGDIR: {logdir}")

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed, workers=True)

    # move before model init, in case a torch.compile(...) is called somewhere
    if opt.enable_tf32:
        # pt_version = version.parse(torch.__version__)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"Enabling TF32 for PyTorch {torch.__version__}")
    else:
        print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
        print(
            f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
        )
        print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")

    # Bound before the try block so the handlers below can tell whether the
    # trainer was ever created instead of failing with a NameError.
    trainer = None
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        # default to gpu
        trainer_config["accelerator"] = "gpu"
        #
        standard_args = default_trainer_args()
        for k in standard_args:
            if getattr(opt, k) != standard_args[k]:
                trainer_config[k] = getattr(opt, k)

        # The option may be absent, depending on how the parser was built.
        ckpt_resume_path = getattr(opt, "resume_from_checkpoint", None)

        if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["devices"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    # "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                    "project": opt.projectname,
                    "log_model": False,
                    # "dir": logdir,
                },
            },
            "csv": {
                "target": "pytorch_lightning.loggers.CSVLogger",
                "params": {
                    "name": "testtube",  # hack for sbord fanatics
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
        if opt.wandb:
            # TODO change once leaving "swiffer" config directory
            try:
                group_name = nowname.split(now)[-1].split("-")[1]
            except IndexError:
                group_name = nowname
            default_logger_cfg["params"]["group"] = group_name
            init_wandb(
                os.path.join(os.getcwd(), logdir),
                opt=opt,
                group_name=group_name,
                config=config,
                name_str=nowname,
            )
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            },
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")

        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
        # default to ddp if not further specified
        default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}

        if "strategy" in lightning_config:
            strategy_cfg = lightning_config.strategy
        else:
            strategy_cfg = OmegaConf.create()
            default_strategy_config["params"] = {
                "find_unused_parameters": False,
                # "static_graph": True,
                # "ddp_comm_hook": default.fp16_compress_hook  # TODO: experiment with this, also for DDPSharded
            }
        strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
        print(
            f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
        )
        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                    "debug": opt.debug,
                    "ckpt_name": melk_ckpt_name,
                },
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
            },
            "learning_rate_logger": {
                "target": "pytorch_lightning.callbacks.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                },
            },
        }
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            print(
                "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                "metrics_over_trainsteps_checkpoint": {
                    "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                    "params": {
                        "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        "save_top_k": -1,
                        "every_n_train_steps": 10000,
                        "save_weights_only": True,
                    },
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
            callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
        elif "ignore_keys_callback" in callbacks_cfg:
            del callbacks_cfg["ignore_keys_callback"]

        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]
        if "plugins" not in trainer_kwargs:
            trainer_kwargs["plugins"] = list()

        # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
        trainer_opt = vars(trainer_opt)
        trainer_kwargs = {
            key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
        }
        trainer = Trainer(**trainer_opt, **trainer_kwargs)

        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        # data.setup()
        print("#### Data #####")
        try:
            for k in data.datasets:
                print(
                    f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
                )
        except Exception:
            # Best effort only: the datamodule may not expose datasets yet.
            print("datasets not yet initialized.")

        # configure learning rate
        if "batch_size" in config.data.params:
            bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        else:
            bs, base_lr = (
                config.data.params.train.loader.batch_size,
                config.model.base_learning_rate,
            )
        if not cpu:
            devices = lightning_config.trainer.devices
            if isinstance(devices, int):
                # An integer is a device count; non-positive means "all".
                ngpu = devices if devices > 0 else torch.cuda.device_count()
            elif isinstance(devices, str):
                ngpu = len(devices.strip(",").split(","))
            else:
                ngpu = len(devices)
        else:
            ngpu = 1
        if "accumulate_grad_batches" in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
                )
            )
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                if melk_ckpt_name is None:
                    ckpt_path = os.path.join(ckptdir, "last.ckpt")
                else:
                    ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data, ckpt_path=ckpt_resume_path)
            except Exception:
                if not opt.debug:
                    melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except RuntimeError as err:
        if MULTINODE_HACKS:
            import socket

            device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
            hostname = socket.gethostname()
            ts = datetime.datetime.now(datetime.timezone.utc).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
            # The metadata lookup is diagnostic only: it must neither hang
            # nor replace the error that is being reported.
            try:
                import requests

                instance_id = requests.get(
                    "http://169.254.169.254/latest/meta-data/instance-id",
                    timeout=2,
                ).text
            except Exception:
                instance_id = "unknown-instance"
            print(
                f"ERROR at {ts} on {hostname}/{instance_id} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
                flush=True,
            )
        raise
    except Exception:
        if opt.debug and (trainer is None or trainer.global_rank == 0):
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if (
            opt.debug
            and not opt.resume
            and trainer is not None
            and trainer.global_rank == 0
            and os.path.exists(logdir)
        ):
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)

        if opt.wandb:
            wandb.finish()
        # if trainer.global_rank == 0:
        #    print(trainer.profiler.summary())
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works. "Stability AI" or "we" means Stability AI Ltd. and its affiliates. "Software" means Stability AI’s proprietary software made available under this Agreement. “Software Products” means the Models, Software and Documentation, individually or in any combination. 1. License Rights and Redistribution. a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to reproduce the Software Products and produce, reproduce, distribute, and create Derivative Works of the Software Products for Non-Commercial Uses only, respectively. b. 
You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact. c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 4. Intellectual Property. a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works. b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted.
You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement. 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. ================================================ FILE: model_licenses/LICENSE-SDXL0.9 ================================================ SDXL 0.9 RESEARCH LICENSE AGREEMENT Copyright (c) Stability AI Ltd. This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”). By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. 
If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity. 1. LICENSE GRANT a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above. c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License. 2. RESTRICTIONS You will not, and will not permit, assist or cause any third party to: a. 
use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Protection Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing; b. alter or remove copyright and other proprietary notices which appear on or in the Software Products; c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License. e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S.
government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods. 3. ATTRIBUTION Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” 4. DISCLAIMERS THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AI EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. 5. LIMITATION OF LIABILITY TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. 6. INDEMNIFICATION You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined above); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims.
You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties. 7. TERMINATION; SURVIVAL a. This License will automatically terminate upon any breach by you of the terms of this License. b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation of Liability), 6 (Indemnification), 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous). 8. THIRD PARTY MATERIALS The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. 9. TRADEMARKS Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement. 10. APPLICABLE LAW; DISPUTE RESOLUTION This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions.
Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts. 11. MISCELLANEOUS If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI. ================================================ FILE: model_licenses/LICENSE-SDXL1.0 ================================================ Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023 Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. 
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI. This CreativeML Open RAIL++-M License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. NOW THEREFORE, You and Licensor agree as follows: Definitions "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. "Output" means the results of operating a Model as embodied in informational content resulting therefrom. 
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. 
"Third Parties" means individuals or legal entities that are not under common control with Licensor or You. "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. Grant of Patent License. 
Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. 
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; You must cause any modified files to carry prominent notices stating that You changed the files; You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License. Trademarks and related. 
Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. 
However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. END OF TERMS AND CONDITIONS Attachment A Use Restrictions You agree not to use the Model or Derivatives of the Model: In any way that violates any applicable national, federal, state, local or international law or regulation; For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; To generate or disseminate verifiably false information and/or content with the purpose of harming others; To generate or disseminate personal identifiable information that can be used to harm an individual; To defame, disparage or otherwise harass others; For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; 
To provide medical advice and medical results interpretation; To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). ================================================ FILE: model_licenses/LICENSE-SV3D ================================================ STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT Dated: March 18, 2024 "Agreement" means this Stable Non-Commercial Research Community License Agreement. “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. 
“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works. "Stability AI" or "we" means Stability AI Ltd and its affiliates. "Software" means Stability AI’s proprietary software made available under this Agreement. “Software Products” means the Models, Software and Documentation, individually or in any combination. 1. License Rights and Redistribution. a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only. b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact. c. 
If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 4. Intellectual Property. a. 
No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works. b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement. 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. 6. Governing Law. 
This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law principles. ================================================ FILE: model_licenses/LICENSE-SVD ================================================ STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT Dated: November 21, 2023 “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein. "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. "Stability AI" or "we" means Stability AI Ltd. "Software" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. “Software Products” means Software and Documentation. By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement. License Rights and Redistribution. 
Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use. b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 3. Limitation of Liability. 
IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 3. Intellectual Property. a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. 4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. 
Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement. ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [project] name = "sgm" dynamic = ["version"] description = "Stability Generative Models" readme = "README.md" license-files = { paths = ["LICENSE-CODE"] } requires-python = ">=3.8" [project.urls] Homepage = "https://github.com/Stability-AI/generative-models" [tool.hatch.version] path = "sgm/__init__.py" [tool.hatch.build] # This needs to be explicitly set so the configuration files # grafted into the `sgm` directory get included in the wheel's # RECORD file. include = [ "sgm", ] # The force-include configurations below make Hatch copy # the configs/ directory (containing the various YAML files required # to generatively model) into the source distribution and the wheel. 
[tool.hatch.build.targets.sdist.force-include] "./configs" = "sgm/configs" [tool.hatch.build.targets.wheel.force-include] "./configs" = "sgm/configs" [tool.hatch.envs.ci] skip-install = false dependencies = [ "pytest" ] [tool.hatch.envs.ci.scripts] test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", "pip install -r requirements/pt2.txt", "pytest -v tests/inference/test_inference.py {args}", ] ================================================ FILE: pytest.ini ================================================ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') ================================================ FILE: requirements/pt2.txt ================================================ black==23.7.0 chardet==5.1.0 clip @ git+https://github.com/openai/CLIP.git einops>=0.6.1 fairscale>=0.4.13 fire>=0.5.0 fsspec>=2023.6.0 imageio[ffmpeg] imageio[pyav] invisible-watermark>=0.2.0 kornia==0.6.9 matplotlib>=3.7.2 natsort>=8.4.0 ninja>=1.11.1 numpy==2.1 omegaconf>=2.3.0 onnxruntime open-clip-torch>=2.20.0 opencv-python==4.6.0.66 pandas>=2.0.3 pillow>=9.5.0 pudb>=2022.1.3 pytorch-lightning==2.0.1 pyyaml>=6.0.1 rembg scipy>=1.10.1 streamlit>=0.73.1 tensorboardx==2.6 timm>=0.9.2 tokenizers==0.12.1 torch>=2.0.1 torchaudio>=2.0.2 torchdata==0.6.1 torchmetrics>=1.0.1 torchvision>=0.15.2 tqdm>=4.65.0 transformers==4.19.1 triton==2.0.0 urllib3<1.27,>=1.25.4 wandb>=0.15.6 webdataset>=0.2.33 wheel>=0.41.0 xformers>=0.0.20 gradio streamlit-keyup==0.2.0 ================================================ FILE: scripts/__init__.py ================================================ ================================================ FILE: scripts/demo/__init__.py ================================================ ================================================ FILE: scripts/demo/detect.py ================================================ import argparse import cv2 
try:
    from imwatermark import WatermarkDecoder
except ImportError as e:
    try:
        # Assume some of the other dependencies such as torch are not fulfilled
        # import file without loading unnecessary libraries.
        import importlib.util
        import sys

        # Load only the `maxDct` submodule and register it under the bare
        # name "maxDct" so the minimal decoder below can use it.
        spec = importlib.util.find_spec("imwatermark.maxDct")
        assert spec is not None
        maxDct = importlib.util.module_from_spec(spec)
        sys.modules["maxDct"] = maxDct
        spec.loader.exec_module(maxDct)

        class WatermarkDecoder(object):
            """A minimal version of
            https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
            to only reconstruct bits using dwtDct"""

            def __init__(self, wm_type="bytes", length=0):
                # Only the "bits" watermark type is supported by this fallback.
                assert wm_type == "bits", "Only bits defined in minimal import"
                self._wmType = wm_type
                self._wmLen = length

            def reconstruct(self, bits):
                """Return `bits` unchanged after checking it has the expected length.

                Raises:
                    RuntimeError: if `len(bits)` differs from the configured length.
                """
                if len(bits) != self._wmLen:
                    raise RuntimeError("bits are not matched with watermark length")
                return bits

            def decode(self, cv2Image, method="dwtDct", **configs):
                """Decode the watermark bits embedded in `cv2Image`.

                Args:
                    cv2Image: array whose `.shape` unpacks to (rows, cols, channels).
                    method: must be "dwtDct"; no other method exists in this fallback.
                    **configs: forwarded to `maxDct.EmbedMaxDct`.

                Raises:
                    RuntimeError: if the image has fewer than 256 * 256 pixels, or
                        the decoded bit count does not match the configured length.
                """
                (r, c, channels) = cv2Image.shape
                if r * c < 256 * 256:
                    raise RuntimeError("image too small, should be larger than 256x256")

                assert method == "dwtDct"
                embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
                bits = embed.decode(cv2Image)
                return self.reconstruct(bits)

    except Exception as fallback_error:
        # The fallback failed as well: surface the original ImportError, but
        # keep the fallback failure attached as its cause for diagnosis.
        # `Exception` (not a bare `except`) so that KeyboardInterrupt and
        # SystemExit are not silently turned into an ImportError.
        raise e from fallback_error
class GetWatermarkMatch:
    """Callable that counts how many decoded bits agree with a reference watermark."""

    def __init__(self, watermark):
        # Reference bit sequence; the decoder is configured for exactly as
        # many bits as the reference contains.
        self.watermark = watermark
        self.num_bits = len(self.watermark)
        self.decoder = WatermarkDecoder("bits", self.num_bits)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Count the bits of the predefined watermark that match the bits decoded
        from one image or from a batch of images.
        Images should be in cv2 format, e.g. h x w x c BGR.

        Args:
            x: ([B], h, w, c) in range [0, 255]

        Returns:
            number of matched bits ([B],); a scalar when a single 3-d image
            was passed in.
        """
        single_image = x.ndim == 3
        # Promote a lone image to a batch of one so both cases share a path.
        batch = x[None, ...] if single_image else x

        n_images = batch.shape[0]
        decoded_bits = np.empty((n_images, self.num_bits), dtype=bool)
        for idx, image in enumerate(batch):
            decoded_bits[idx] = self.decoder.decode(image, "dwtDct")

        # Element-wise agreement with the reference, summed per image.
        agreement = decoded_bits == self.watermark
        match_counts = np.sum(agreement, axis=-1)

        return match_counts[0] if single_image else match_counts
class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas: only the trailing part of the
    schedule returned by the wrapped discretizer is kept.

    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization: "Discretization", strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretizer and return the last
        ``max(int(strength * len(sigmas)), 1)`` entries of its output.

        All arguments are forwarded unchanged to the wrapped discretizer.
        """
        # sigmas start large first, and decrease then (per the original
        # author's note; the wrapped discretizer is not visible here)
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Compute the cut-off once, from the FULL schedule. It used to be
        # recomputed for the log message after slicing, which reported a
        # value derived from the already-pruned length.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        # flip -> keep the first `prune_index` -> flip back
        # keeps the last `prune_index` entries in their original order.
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas: trailing entries of the
    schedule returned by the wrapped discretizer are dropped.

    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
        original_steps: if given, the prune index is computed from
            ``original_steps + 1`` instead of the actual schedule length
    """

    def __init__(
        self,
        discretization: "Discretization",
        strength: float = 0.0,
        original_steps=None,
    ):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretizer and drop its last ``prune_index`` entries.

        ``prune_index`` is ``int(strength * steps) - 1`` clamped to
        ``[0, steps - 1]``. All arguments are forwarded unchanged to the
        wrapped discretizer.
        """
        # sigmas start large first, and decrease then (per the original
        # author's note; the wrapped discretizer is not visible here)
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        # On the flipped schedule this removes the first `prune_index`
        # entries, i.e. the last ones of the original ordering.
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
local_dir="checkpoints") # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints") # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints") # Define the repo, local directory and filename repo_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models local_dir = "checkpoints" local_file_path = os.path.join(local_dir, filename) # Check if the file already exists if not os.path.exists(local_file_path): # If the file doesn't exist, download it hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) print("File downloaded.") else: print("File already exists. No need to download.") version = "svd_xt_1_1" # replace with 'svd_xt' or 'svd' for other models device = "cuda" max_64_bit_int = 2**63 - 1 if version == "svd_xt_1_1": num_frames = 25 num_steps = 30 model_config = "scripts/sampling/configs/svd_xt_1_1.yaml" else: raise ValueError(f"Version {version} does not exist.") model, filter = load_model( model_config, device, num_frames, num_steps, ) def sample( input_path: str = "assets/test_image.png", # Can either be image file or folder with image files seed: Optional[int] = None, randomize_seed: bool = True, motion_bucket_id: int = 127, fps_id: int = 6, version: str = "svd_xt_1_1", cond_aug: float = 0.02, decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", output_folder: str = "outputs", progress=gr.Progress(track_tqdm=True), ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. 
If you run out of VRAM, try decreasing `decoding_t`. """ fps_id = int(fps_id) # casting float slider values to int) if randomize_seed: seed = random.randint(0, max_64_bit_int) torch.manual_seed(seed) path = Path(input_path) all_img_paths = [] if path.is_file(): if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): all_img_paths = [input_path] else: raise ValueError("Path is not valid image file.") elif path.is_dir(): all_img_paths = sorted( [ f for f in path.iterdir() if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] ] ) if len(all_img_paths) == 0: raise ValueError("Folder does not contain any images.") else: raise ValueError for input_img_path in all_img_paths: with Image.open(input_img_path) as image: if image.mode == "RGBA": image = image.convert("RGB") w, h = image.size if h % 64 != 0 or w % 64 != 0: width, height = map(lambda x: x - x % 64, (w, h)) image = image.resize((width, height)) print( f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" ) image = ToTensor()(image) image = image * 2.0 - 1.0 image = image.unsqueeze(0).to(device) H, W = image.shape[2:] assert image.shape[1] == 3 F = 8 C = 4 shape = (num_frames, C, H // F, W // F) if (H, W) != (576, 1024): print( "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`." ) if motion_bucket_id > 255: print( "WARNING: High motion bucket! This may lead to suboptimal performance." ) if fps_id < 5: print("WARNING: Small fps value! This may lead to suboptimal performance.") if fps_id > 30: print("WARNING: Large fps value! 
This may lead to suboptimal performance.") value_dict = {} value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug value_dict["cond_frames_without_noise"] = image value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) value_dict["cond_aug"] = cond_aug with torch.no_grad(): with torch.autocast(device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1, num_frames], T=num_frames, device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) randn = torch.randn(shape, device=device) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros( 2, num_frames ).to(device) additional_model_inputs["num_video_frames"] = batch["num_video_frames"] def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) os.makedirs(output_folder, exist_ok=True) base_count = len(glob(os.path.join(output_folder, "*.mp4"))) video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") writer = cv2.VideoWriter( video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps_id + 1, (samples.shape[-1], samples.shape[-2]), ) samples = embed_watermark(samples) samples = filter(samples) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255) .cpu() .numpy() .astype(np.uint8) ) for frame in vid: frame = 
def resize_image(image_path, output_size=(1024, 576)):
    """Scale the image at `image_path` to cover `output_size`, then center-crop it.

    The image is resized with LANCZOS resampling, preserving its aspect ratio,
    so that it is at least as large as `output_size` in both dimensions. The
    overflowing dimension is then cropped symmetrically.

    Args:
        image_path: path (or anything `PIL.Image.open` accepts) of the source image.
        output_size: target `(width, height)` in pixels.

    Returns:
        The resized and cropped PIL image.
    """
    target_width, target_height = output_size

    # Calculate aspect ratios
    target_aspect = target_width / target_height  # Aspect ratio of the desired size

    # Use a context manager so the file handle opened by PIL is released as
    # soon as the pixel data has been resampled. `resize` returns a new image,
    # so nothing below depends on the file staying open.
    with Image.open(image_path) as image:
        image_aspect = image.width / image.height  # Aspect ratio of the original image

        if image_aspect > target_aspect:
            # Source is wider than the target: match the target height and
            # crop the excess width equally on the left and right.
            new_height = target_height
            new_width = int(new_height * image_aspect)
            left = (new_width - target_width) / 2
            top = 0
            right = (new_width + target_width) / 2
            bottom = target_height
        else:
            # Source is taller than (or as wide as) the target: match the
            # target width and crop the excess height equally top and bottom.
            new_width = target_width
            new_height = int(new_width / image_aspect)
            left = 0
            top = (new_height - target_height) / 2
            right = target_width
            bottom = (new_height + target_height) / 2

        resized_image = image.resize((new_width, new_height), Image.LANCZOS)

    # Crop the image
    cropped_image = resized_image.crop((left, top, right, bottom))
    return cropped_image
""" ) with gr.Row(): with gr.Column(): image = gr.Image(label="Upload your image", type="filepath") generate_btn = gr.Button("Generate") video = gr.Video() with gr.Accordion("Advanced options", open=False): seed = gr.Slider( label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) motion_bucket_id = gr.Slider( label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255, ) fps_id = gr.Slider( label="Frames per second", info="The length of your video in seconds will be 25/fps", value=6, minimum=5, maximum=30, ) image.upload(fn=resize_image, inputs=image, outputs=image, queue=False) generate_btn.click( fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video", ) if __name__ == "__main__": demo.queue(max_size=20) demo.launch(share=True) ================================================ FILE: scripts/demo/gradio_app_sv4d.py ================================================ # Adding this at the very top of app.py to make 'generative-models' directory discoverable import os import sys sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models")) from glob import glob from typing import Optional import gradio as gr import numpy as np import torch from huggingface_hub import hf_hub_download from typing import List, Optional, Union import torchvision from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder from scripts.demo.sv4d_helpers import ( decode_latents, load_model, initial_model_load, read_video, run_img2vid, prepare_inputs, do_sample_per_step, sample_sv3d, save_video, preprocess_video, ) # the tmp path, if /tmp/gradio is not writable, change it to a writable path # os.environ["GRADIO_TEMP_DIR"] = "gradio_tmp" version = "sv4d" # replace with 'sv3d_p' or 'sv3d_u' for other models # Define the repo, local directory and filename 
def _ensure_checkpoint(repo_id: str, filename: str, local_dir: str, tag: str) -> str:
    """Return the local checkpoint path, downloading the file first if it is missing.

    Args:
        repo_id: Hugging Face Hub repository that hosts the checkpoint.
        filename: Checkpoint file name inside the repository.
        local_dir: Directory the checkpoint is stored in.
        tag: Short model label appended to the progress message.
    """
    ckpt_path = os.path.join(local_dir, filename)
    if os.path.exists(ckpt_path):
        print(f"File already exists. No need to download. ({tag})")
    else:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        print(f"File downloaded. ({tag})")
    return ckpt_path


# Repo, local directory and filename of the SV4D checkpoint
repo_id = "stabilityai/sv4d"
filename = f"{version}.safetensors"
local_dir = "checkpoints"
local_ckpt_path = _ensure_checkpoint(repo_id, filename, local_dir, "sv4d")

device = "cuda"
max_64_bit_int = 2**63 - 1

num_frames = 21
num_steps = 20
model_config = f"scripts/sampling/configs/{version}.yaml"

# Set model config
T = 5  # number of frames per sample
V = 8  # number of views per sample
F = 8  # vae factor to downsize image->latent
C = 4
H, W = 576, 576
n_frames = 21  # number of input and output video frames
n_views = V + 1  # number of output video views (1 input view + 8 novel views)
n_views_sv3d = 21
subsampled_views = np.array(
    [0, 2, 5, 7, 9, 12, 14, 16, 19]
)  # subsample (V+1=)9 (uniform) views from 21 SV3D views

version_dict = {
    "T": T * V,
    "H": H,
    "W": W,
    "C": C,
    "f": F,
    "options": {
        "discretization": 1,
        "cfg": 3,
        "sigma_min": 0.002,
        "sigma_max": 700.0,
        "rho": 7.0,
        "guider": 5,
        "num_steps": num_steps,
        "force_uc_zero_embeddings": [
            "cond_frames",
            "cond_frames_without_noise",
            "cond_view",
            "cond_motion",
        ],
        "additional_guider_kwargs": {
            "additional_cond_keys": ["cond_view", "cond_motion"]
        },
    },
}

# Load SV4D model.
# NOTE: `filter` shadows the builtin; the name is kept because it is a
# module-level binding that other code may already rely on.
model, filter = load_model(
    model_config,
    device,
    version_dict["T"],
    num_steps,
)
model = initial_model_load(model)

# -----------sv3d config and model loading----------------
# The demo always pairs SV4D with the "sv3d_u" variant; switch both the config
# and the checkpoint filename below to use "sv3d_p" instead.
sv3d_model_config = "scripts/sampling/configs/sv3d_u.yaml"

# Repo, local directory and filename of the SV3D checkpoint
repo_id = "stabilityai/sv3d"
filename = "sv3d_u.safetensors"  # or "sv3d_p.safetensors"
local_dir = "checkpoints"
local_ckpt_path = _ensure_checkpoint(repo_id, filename, local_dir, "sv3d")

# load sv3d model
sv3d_model, filter = load_model(
    sv3d_model_config,
    device,
    21,
    num_steps,
    verbose=False,
)
sv3d_model = initial_model_load(sv3d_model)
# ------------------


def sample_anchor(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    seed: Optional[int] = None,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    num_steps: int = 20,
    sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 1e-5,
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
    verbose: Optional[bool] = False,
):
    """Sample the SV3D views of the first frame and the SV4D anchor frames.

    Reads `n_frames` frames from `input_path`, runs SV3D on the first frame to
    get `n_views_sv3d` views, then runs one SV4D pass that fills the anchor
    frames (every `T - 1`-th frame) for all `V` novel views. Three videos are
    written next to `input_path`: "t000.mp4", "anchor_vis.mp4" and "anchor.mp4".

    Args:
        input_path: Input video path; outputs are written to its directory.
        seed: Torch RNG seed. `None` leaves the global RNG state untouched.
        encoding_t: Frames encoded at a time by the video conditioner embedder.
        decoding_t: Frames decoded at a time.
        num_steps: Number of sampler steps for both SV3D and SV4D.
        sv3d_version, fps_id, motion_bucket_id, cond_aug, verbose: Forwarded
            unchanged to `sample_sv3d`.
        device: Device the input frames are read onto.
        elevations_deg: One elevation for all views, or `n_views_sv3d` values.
        azimuths_deg: `n_views_sv3d` azimuths; defaults to a uniform orbit.

    Returns:
        Tuple `(sv3d_file, anchor_vis_file, anchor_file)` of written video paths.

    If you run out of VRAM, try decreasing `decoding_t`.
    """
    # dirname is "" for a bare file name, which os.makedirs rejects.
    output_folder = os.path.dirname(input_path) or "."
    # torch.manual_seed requires an int; only seed when one was supplied.
    if seed is not None:
        torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    images_v0 = read_video(
        input_path,
        n_frames=n_frames,
        device=device,
    )

    # Get camera viewpoints
    if isinstance(elevations_deg, (float, int)):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Sample multi-view images of the first frame using SV3D i.e. images at time 0
    sv3d_model.sampler.num_steps = num_steps
    print("sv3d_model.sampler.num_steps", sv3d_model.sampler.num_steps)
    images_t0 = sample_sv3d(
        images_v0[0],
        n_views_sv3d,
        num_steps,
        sv3d_version,
        fps_id,
        motion_bucket_id,
        cond_aug,
        decoding_t,
        device,
        polars_rad,
        azimuths_rad,
        verbose,
        sv3d_model,
    )
    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame

    sv3d_file = os.path.join(output_folder, "t000.mp4")
    save_video(sv3d_file, images_t0.unsqueeze(1))

    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t

    # Initialize image matrix: rows are frames, columns are views
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # Interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1
    print(f"Sampling anchor frames {frame_indices}")
    image = img_matrix[t0][v0]
    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
    model.sampler.num_steps = num_steps
    version_dict["options"]["num_steps"] = num_steps
    samples = run_img2vid(
        version_dict,
        model,
        image,
        seed,
        polars,
        azims,
        cond_motion,
        cond_view,
        decoding_t,
    )
    samples = samples.view(T, V, 3, H, W)

    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # concat video
    grid_list = []
    for t in frame_indices:
        imgs_view = torch.cat(img_matrix[t])
        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))

    # save output videos
    anchor_vis_file = os.path.join(output_folder, "anchor_vis.mp4")
    save_video(anchor_vis_file, grid_list, fps=3)
    anchor_file = os.path.join(output_folder, "anchor.mp4")
    image_list = samples.view(T * V, 3, H, W).unsqueeze(1) * 2 - 1
    save_video(anchor_file, image_list)

    return sv3d_file, anchor_vis_file, anchor_file
def sample_all(
    input_path: str = "inputs/test_video1.mp4",  # Can either be video file or folder with image files
    sv3d_path: str = "outputs/sv4d/000000_t000.mp4",
    anchor_path: str = "outputs/sv4d/000000_anchor.mp4",
    seed: Optional[int] = None,
    num_steps: int = 20,
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
):
    """Densely sample all `n_frames` frames for every novel view.

    Loads the input video, the SV3D first-frame views (`sv3d_path`) and the
    previously sampled anchor frames (`anchor_path`), then fills the remaining
    frames window by window, alternating forward- and backward-conditioned
    sampler steps. The result grid is written as "sv4d_final.mp4" next to
    `input_path`.

    Args:
        input_path: Input video path; the output is written to its directory.
        sv3d_path: Video holding the `n_views_sv3d` SV3D views of frame 0.
        anchor_path: Video holding the `T * V` anchor frames.
        seed: Torch RNG seed. `None` leaves the global RNG state untouched.
        num_steps: Number of sampler steps per window.
        device: Device for the loaded frames and the latent noise.
        elevations_deg: One elevation for all views, or `n_views_sv3d` values.
        azimuths_deg: `n_views_sv3d` azimuths; defaults to a uniform orbit.

    Returns:
        Tuple `(vid_file, seed)`.
    """
    # dirname is "" for a bare file name, which os.makedirs rejects.
    output_folder = os.path.dirname(input_path) or "."
    # torch.manual_seed requires an int; only seed when one was supplied.
    if seed is not None:
        torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    images_v0 = read_video(
        input_path,
        n_frames=n_frames,
        device=device,
    )

    images_t0 = read_video(
        sv3d_path,
        n_frames=n_views_sv3d,
        device=device,
    )

    # Get camera viewpoints
    if isinstance(elevations_deg, (float, int)):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Initialize image matrix: rows are frames, columns are views
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v]
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # load interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1

    anchor_frames = read_video(
        anchor_path,
        n_frames=T * V,
        device=device,
    )
    anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = anchor_frames[i, j][None]

    # Dense sampling for the rest
    print("Sampling dense frames:")
    for t0 in np.arange(0, n_frames - 1, T - 1):  # [0, 4, 8, 12, 16]
        frame_indices = t0 + np.arange(T)
        print(f"Sampling dense frames {frame_indices}")
        # Fresh noise per window, created on the requested device.
        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to(device)

        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)

        # alternate between forward and backward conditioning
        (
            forward_inputs,
            forward_frame_indices,
            backward_inputs,
            backward_frame_indices,
        ) = prepare_inputs(
            frame_indices,
            img_matrix,
            v0,
            view_indices,
            model,
            version_dict,
            seed,
            polars,
            azims,
        )

        for step in range(num_steps):
            # NOTE: `frame_indices` is deliberately rebound here; the value from
            # the last step is what `decode_latents` receives below.
            if step % 2 == 1:
                c, uc, additional_model_inputs, sampler = forward_inputs
                frame_indices = forward_frame_indices
            else:
                c, uc, additional_model_inputs, sampler = backward_inputs
                frame_indices = backward_frame_indices
            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)

            samples = do_sample_per_step(
                model,
                sampler,
                noisy_latents,
                c,
                uc,
                step,
                additional_model_inputs,
            )
            samples = samples.view(T, V, C, H // F, W // F)
            for i, t in enumerate(frame_indices):
                for j, v in enumerate(view_indices):
                    latent_matrix[t, v] = samples[i, j]

        img_matrix = decode_latents(
            model, latent_matrix, img_matrix, frame_indices, view_indices, T
        )

    # concat video
    grid_list = []
    for t in range(n_frames):
        imgs_view = torch.cat(img_matrix[t])
        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))

    # save output videos
    vid_file = os.path.join(output_folder, "sv4d_final.mp4")
    save_video(vid_file, grid_list)

    return vid_file, seed
""" ) with gr.Row(): with gr.Column(): input_video = gr.Video(label="Upload your video") generate_btn = gr.Button("Step 1: generate 8 novel view videos (5 anchor frames each)") interpolate_btn = gr.Button("Step 2: Extend novel view videos to 21 frames") with gr.Column(): anchor_video = gr.Video(label="SV4D outputs (anchor frames)") sv3d_video = gr.Video(label="SV3D outputs", interactive=False) with gr.Column(): sv4d_interpolated_video = gr.Video(label="SV4D outputs (21 frames)") with gr.Accordion("Advanced options", open=False): seed = gr.Slider( label="Seed", value=23, # randomize=True, minimum=0, maximum=100, step=1, ) encoding_t = gr.Slider( label="Encode n frames at a time", info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.", value=8, minimum=1, maximum=40, ) decoding_t = gr.Slider( label="Decode n frames at a time", info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.", value=4, minimum=1, maximum=14, ) denoising_steps = gr.Slider( label="Number of denoising steps", info="Increase will improve the performance but needs more time.", value=20, minimum=10, maximum=50, step=1, ) remove_bg = gr.Checkbox( label="Remove background", info="We use rembg. 
Users can check the alternative way: SAM2 (https://github.com/facebookresearch/segment-anything-2)", ) input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False) with gr.Row(visible=False): anchor_frames = gr.Video() generate_btn.click( fn=sample_anchor, inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps], outputs=[sv3d_video, anchor_video, anchor_frames], api_name="SV4D output (5 frames)", ) interpolate_btn.click( fn=sample_all, inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps], outputs=[sv4d_interpolated_video, seed], api_name="SV4D interpolation (21 frames)", ) examples = gr.Examples( fn=preprocess_video, examples=[ "./assets/sv4d_videos/test_video1.mp4", "./assets/sv4d_videos/test_video2.mp4", "./assets/sv4d_videos/green_robot.mp4", "./assets/sv4d_videos/dolphin.mp4", "./assets/sv4d_videos/lucia_v000.mp4", "./assets/sv4d_videos/snowboard_v000.mp4", "./assets/sv4d_videos/stroller_v000.mp4", "./assets/sv4d_videos/human5.mp4", "./assets/sv4d_videos/bunnyman.mp4", "./assets/sv4d_videos/hiphop_parrot.mp4", "./assets/sv4d_videos/guppie_v0.mp4", "./assets/sv4d_videos/wave_hello.mp4", "./assets/sv4d_videos/pistol_v0.mp4", "./assets/sv4d_videos/human7.mp4", "./assets/sv4d_videos/monkey.mp4", "./assets/sv4d_videos/train_v0.mp4", ], inputs=[input_video], run_on_click=True, outputs=[input_video], ) if __name__ == "__main__": demo.queue(max_size=20) demo.launch(share=True) ================================================ FILE: scripts/demo/sampling.py ================================================ from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * SAVE_PATH = "outputs/demo/txt2img/" SD_XL_BASE_RATIOS = { "0.5": (704, 1408), "0.52": (704, 1344), "0.57": (768, 1344), "0.6": (768, 1280), "0.68": (832, 1216), "0.72": (832, 1152), "0.78": (896, 1152), "0.82": (896, 1088), "0.88": (960, 1088), "0.94": (960, 1024), "1.0": (1024, 1024), "1.07": (1024, 960), 
"1.13": (1088, 960), "1.21": (1088, 896), "1.29": (1152, 896), "1.38": (1152, 832), "1.46": (1216, 832), "1.67": (1280, 768), "1.75": (1344, 768), "1.91": (1344, 704), "2.0": (1408, 704), "2.09": (1472, 704), "2.4": (1536, 640), "2.5": (1600, 640), "2.89": (1664, 576), "3.0": (1728, 576), } VERSION2SPECS = { "SDXL-base-1.0": { "H": 1024, "W": 1024, "C": 4, "f": 8, "is_legacy": False, "config": "configs/inference/sd_xl_base.yaml", "ckpt": "checkpoints/sd_xl_base_1.0.safetensors", }, "SDXL-base-0.9": { "H": 1024, "W": 1024, "C": 4, "f": 8, "is_legacy": False, "config": "configs/inference/sd_xl_base.yaml", "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", }, "SDXL-refiner-0.9": { "H": 1024, "W": 1024, "C": 4, "f": 8, "is_legacy": True, "config": "configs/inference/sd_xl_refiner.yaml", "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", }, "SDXL-refiner-1.0": { "H": 1024, "W": 1024, "C": 4, "f": 8, "is_legacy": True, "config": "configs/inference/sd_xl_refiner.yaml", "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors", }, } def load_img(display=True, key=None, device="cuda"): image = get_interactive_image(key=key) if image is None: return None if display: st.image(image) w, h = image.size print(f"loaded input image of size ({w}, {h})") width, height = map( lambda x: x - x % 64, (w, h) ) # resize to integer multiple of 64 image = image.resize((width, height)) image = np.array(image.convert("RGB")) image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image.to(device) def run_txt2img( state, version, version_dict, is_legacy=False, return_latents=False, filter=None, stage2strength=None, ): if version.startswith("SDXL-base"): W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) else: H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048) W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048) C = version_dict["C"] F = version_dict["f"] 
def run_txt2img(
    state,
    version,
    version_dict,
    is_legacy=False,
    return_latents=False,
    filter=None,
    stage2strength=None,
):
    """Render the txt2img controls and sample once the "Sample" button is pressed.

    `prompt` and `negative_prompt` are not parameters; they are read from the
    module globals set up by the script entry point.

    Returns:
        Whatever `do_sample` returns when the button was pressed, otherwise
        `None`.
    """
    # SDXL-base models pick from the preset (W, H) table (default entry 10);
    # every other model gets free-form height/width inputs.
    if not version.startswith("SDXL-base"):
        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
    else:
        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
    latent_channels = version_dict["C"]
    downscale_factor = version_dict["f"]

    size_info = dict(orig_width=W, orig_height=H, target_width=W, target_height=H)
    embedder_keys = get_unique_embedder_keys_from_conditioner(
        state["model"].conditioner
    )
    value_dict = init_embedder_options(
        embedder_keys,
        size_info,
        prompt=prompt,
        negative_prompt=negative_prompt,
    )

    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
    num_samples = num_rows * num_cols

    if not st.button("Sample"):
        return None

    st.write(f"**Model I:** {version}")
    # Legacy models keep their text embedding for the unconditional branch.
    zeroed_uc_keys = [] if is_legacy else ["txt"]
    return do_sample(
        state["model"],
        sampler,
        value_dict,
        num_samples,
        H,
        W,
        latent_channels,
        downscale_factor,
        force_uc_zero_embeddings=zeroed_uc_keys,
        return_latents=return_latents,
        filter=filter,
    )
def run_img2img(
    state,
    version_dict,
    is_legacy=False,
    return_latents=False,
    filter=None,
    stage2strength=None,
):
    """Render the img2img controls and sample once the "Sample" button is pressed.

    `prompt` and `negative_prompt` are read from the module globals set up by
    the script entry point.

    Returns:
        `None` when no image was uploaded or the button was not pressed,
        otherwise whatever `do_img2img` returns.
    """
    img = load_img()
    if img is None:
        return None

    # Original and target sizes are both taken from the uploaded image tensor.
    height, width = img.shape[2], img.shape[3]
    size_info = dict(
        orig_width=width,
        orig_height=height,
        target_width=width,
        target_height=height,
    )
    embedder_keys = get_unique_embedder_keys_from_conditioner(
        state["model"].conditioner
    )
    value_dict = init_embedder_options(
        embedder_keys,
        size_info,
        prompt=prompt,
        negative_prompt=negative_prompt,
    )

    strength = st.number_input(
        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
    )
    sampler, num_rows, num_cols = init_sampling(
        img2img_strength=strength,
        stage2strength=stage2strength,
    )
    num_samples = num_rows * num_cols

    if not st.button("Sample"):
        return None

    # Legacy models keep their text embedding for the unconditional branch.
    zeroed_uc_keys = [] if is_legacy else ["txt"]
    batch = repeat(img, "1 ... -> n ...", n=num_samples)
    return do_img2img(
        batch,
        state["model"],
        sampler,
        value_dict,
        num_samples,
        force_uc_zero_embeddings=zeroed_uc_keys,
        return_latents=return_latents,
        filter=filter,
    )


def apply_refiner(
    input,
    state,
    sampler,
    num_samples,
    prompt,
    negative_prompt,
    filter=None,
    finish_denoising=False,
):
    """Run the refiner model on `input` via `do_img2img` with encoding skipped.

    The size conditioning is derived from `input`'s last two dimensions
    multiplied by 8; crop offsets and aesthetic scores are fixed values.
    Noise is added only when `finish_denoising` is false.

    Returns:
        The samples returned by `do_img2img`.
    """
    width_px = input.shape[3] * 8
    height_px = input.shape[2] * 8
    value_dict = {
        "orig_width": width_px,
        "orig_height": height_px,
        "target_width": width_px,
        "target_height": height_px,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "crop_coords_top": 0,
        "crop_coords_left": 0,
        "aesthetic_score": 6.0,
        "negative_aesthetic_score": 2.5,
    }

    st.warning(f"refiner input shape: {input.shape}")
    return do_img2img(
        input,
        state["model"],
        sampler,
        value_dict,
        num_samples,
        skip_encode=True,
        filter=filter,
        add_noise=not finish_denoising,
    )
= "" # which is unused stage2strength = None finish_denoising = False if add_pipeline: st.write("__________________________") version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"]) st.warning( f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) " ) st.write("**Refiner Options:**") version_dict2 = VERSION2SPECS[version2] state2 = init_st(version_dict2, load_filter=False) st.info(state2["msg"]) stage2strength = st.number_input( "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0 ) sampler2, *_ = init_sampling( key=2, img2img_strength=stage2strength, specify_num_samples=False, ) st.write("__________________________") finish_denoising = st.checkbox("Finish denoising with refiner.", True) if not finish_denoising: stage2strength = None if mode == "txt2img": out = run_txt2img( state, version, version_dict, is_legacy=is_legacy, return_latents=add_pipeline, filter=state.get("filter"), stage2strength=stage2strength, ) elif mode == "img2img": out = run_img2img( state, version_dict, is_legacy=is_legacy, return_latents=add_pipeline, filter=state.get("filter"), stage2strength=stage2strength, ) elif mode == "skip": out = None else: raise ValueError(f"unknown mode {mode}") if isinstance(out, (tuple, list)): samples, samples_z = out else: samples = out samples_z = None if add_pipeline and samples_z is not None: st.write("**Running Refinement Stage**") samples = apply_refiner( samples_z, state2, sampler2, samples_z.shape[0], prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", filter=state.get("filter"), finish_denoising=finish_denoising, ) if save_locally and samples is not None: perform_save_locally(save_path, samples) ================================================ FILE: scripts/demo/streamlit_helpers.py ================================================ import copy import math import os from glob import glob from typing import Dict, List, Optional, Tuple, Union import cv2 import imageio import 
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    """Load the model described by `version_dict` and return the demo state.

    Args:
        version_dict: Mapping with at least the "config" (OmegaConf file path)
            and "ckpt" (checkpoint path) entries.
        load_ckpt: When false the model is instantiated without weights and
            the "ckpt" entry of the state is `None`.
        load_filter: When true a `DeepFloydDataFiltering` instance is stored
            under "filter".

    Returns:
        Dict with keys "msg", "model", "ckpt", "config" and, optionally,
        "filter".
    """
    config_path, ckpt_path = version_dict["config"], version_dict["ckpt"]
    if not load_ckpt:
        ckpt_path = None

    config = OmegaConf.load(config_path)
    model, msg = load_model_from_config(config, ckpt_path)

    state = {"msg": msg, "model": model, "ckpt": ckpt_path, "config": config}
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state


def load_model(model):
    """Move `model` to the CUDA device (in place; nothing is returned)."""
    model.cuda()


# Module-wide switch toggled through `set_lowvram_mode`.
lowvram_mode = False


def set_lowvram_mode(mode):
    """Set the module-wide low-VRAM flag read by the load/unload helpers."""
    global lowvram_mode
    lowvram_mode = mode


def initial_model_load(model):
    """Prepare a freshly built model according to the low-VRAM flag.

    In low-VRAM mode only `model.model` is cast to half precision and nothing
    is moved to the GPU here; otherwise the whole model is moved to CUDA.

    Returns:
        The same `model` object.
    """
    if lowvram_mode:
        model.model.half()
    else:
        model.cuda()
    return model


def unload_model(model):
    """In low-VRAM mode, move `model` back to the CPU and free cached CUDA memory."""
    if lowvram_mode:
        model.cpu()
        torch.cuda.empty_cache()
def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate `config.model` and optionally load weights from `ckpt`.

    Args:
        config: Config object whose `model` entry is passed to
            `instantiate_from_config`.
        ckpt: Path ending in "ckpt" (torch checkpoint with a "state_dict"
            entry) or "safetensors"; `None` skips weight loading.
        verbose: Print missing/unexpected state-dict keys.

    Returns:
        Tuple `(model, msg)`; the model is in eval mode and `msg` is always
        `None` (kept for the callers that display it).

    Raises:
        NotImplementedError: If `ckpt` has an unsupported extension.
    """
    model = instantiate_from_config(config.model)
    msg = None

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                global_step = pl_sd["global_step"]
                st.info(f"loaded ckpt from global step {global_step}")
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
        else:
            raise NotImplementedError(f"unsupported checkpoint format: {ckpt}")

        # strict=False: mismatching keys are reported rather than fatal.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys:")
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    model = initial_model_load(model)
    model.eval()
    return model, msg


def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return each embedder's `input_key` once, in first-seen order.

    A plain `set` would give an order that depends on string hashing and can
    differ between runs; `dict.fromkeys` keeps the order of
    `conditioner.embedders`, so the widgets built from these keys stay in a
    stable order.
    """
    return list(dict.fromkeys(emb.input_key for emb in conditioner.embedders))
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Collect the conditioning values for every embedder key in `keys`.

    Some keys render Streamlit widgets ("txt", "original_size_as_tuple",
    "crop_coords_top_left", "fps_id"/"fps", "motion_bucket_id", "pool_image");
    the others use fixed values or copy entries from `init_dict`.

    Args:
        keys: Embedder input keys to handle; unknown keys are ignored.
        init_dict: Defaults for "orig_width", "orig_height", "target_width"
            and "target_height" (only the entries a key needs are read).
        prompt: Initial prompt text for the "txt" key.
        negative_prompt: Initial negative prompt text for the "txt" key.

    Returns:
        Dict of conditioning values.
    """
    # Hardcoded demo settings; might undergo some changes in the future
    value_dict = {}
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = "A professional photograph of an astronaut riding a pig"
            if negative_prompt is None:
                negative_prompt = ""

            prompt = st.text_input("Prompt", prompt)
            negative_prompt = st.text_input("Negative prompt", negative_prompt)

            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = negative_prompt

        elif key == "original_size_as_tuple":
            value_dict["orig_width"] = st.number_input(
                "orig_width",
                value=init_dict["orig_width"],
                min_value=16,
            )
            value_dict["orig_height"] = st.number_input(
                "orig_height",
                value=init_dict["orig_height"],
                min_value=16,
            )

        elif key == "crop_coords_top_left":
            value_dict["crop_coords_top"] = st.number_input(
                "crop_coords_top", value=0, min_value=0
            )
            value_dict["crop_coords_left"] = st.number_input(
                "crop_coords_left", value=0, min_value=0
            )

        elif key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        elif key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

        elif key in ("fps_id", "fps"):
            fps = st.number_input("fps", value=6, min_value=1)
            value_dict["fps"] = fps
            value_dict["fps_id"] = fps - 1

        elif key == "motion_bucket_id":
            value_dict["motion_bucket_id"] = st.number_input(
                "motion bucket id", 0, 511, value=127
            )

        elif key == "pool_image":
            st.text("Image for pool conditioning")
            # NOTE(review): passes `key=`; confirm `load_img` accepts it.
            image = load_img(
                key="pool_image_input",
                size=224,
                center_crop=True,
            )
            if image is None:
                st.info("Need an image here")
                image = torch.zeros(1, 3, 224, 224)
            value_dict["pool_image"] = image

    return value_dict


def perform_save_locally(save_path, samples):
    """Watermark `samples` and write each one as a numbered PNG in `save_path`.

    File names continue from the number of entries already in the directory.
    Samples are scaled by 255 before conversion to uint8 — presumably they are
    in [0, 1]; verify against the caller.
    """
    os.makedirs(save_path, exist_ok=True)
    base_count = len(os.listdir(save_path))
    samples = embed_watermark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(sample.astype(np.uint8)).save(
            os.path.join(save_path, f"{base_count:09}.png")
        )
        base_count += 1


def init_save_locally(_dir, init_value: bool = False):
    """Render the "save locally" controls.

    Returns:
        Tuple `(save_locally, save_path)`; `save_path` is `None` when saving
        is switched off.
    """
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    save_path = None
    if save_locally:
        save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
    return save_locally, save_path


def get_guider(options, key):
    """Render the guider controls and build the matching guider config.

    Args:
        options: Defaults: "guider" (index into the choice list), "cfg",
            "min_cfg", "num_frames" (required by the two prediction guiders)
            and "additional_guider_kwargs" (merged into "params"). The dict is
            not modified.
        key: Suffix that keeps the widgets of several samplers distinct.

    Returns:
        Dict with a "target" class path and, except for `IdentityGuider`,
        a "params" dict.

    Raises:
        NotImplementedError: If the selected guider is not handled.
    """
    guider = st.sidebar.selectbox(
        # The label used to read "Discretization", copied from the
        # discretization selectbox.
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
            "LinearPredictionGuider",
            "TrianglePredictionGuider",
        ],
        options.get("guider", 0),
    )

    # Read without popping so the caller's options dict stays intact.
    additional_guider_kwargs = options.get("additional_guider_kwargs", {})

    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale = st.number_input(
            f"cfg-scale #{key}",
            value=options.get("cfg", 5.0),
            min_value=0.0,
        )
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {
                "scale": scale,
                **additional_guider_kwargs,
            },
        }
    elif guider == "LinearPredictionGuider":
        max_scale = st.number_input(
            f"max-cfg-scale #{key}",
            value=options.get("cfg", 1.5),
            min_value=1.0,
        )
        # `#{key}` keeps this widget distinct when two samplers are shown.
        min_scale = st.sidebar.number_input(
            f"min guidance scale #{key}",
            value=options.get("min_cfg", 1.0),
            min_value=1.0,
            max_value=10.0,
        )
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
            "params": {
                "max_scale": max_scale,
                "min_scale": min_scale,
                "num_frames": options["num_frames"],
                **additional_guider_kwargs,
            },
        }
    elif guider == "TrianglePredictionGuider":
        max_scale = st.number_input(
            f"max-cfg-scale #{key}",
            value=options.get("cfg", 2.5),
            min_value=1.0,
            max_value=10.0,
        )
        min_scale = st.sidebar.number_input(
            f"min guidance scale #{key}",
            value=options.get("min_cfg", 1.0),
            min_value=1.0,
            max_value=10.0,
        )
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider",
            "params": {
                "max_scale": max_scale,
                "min_scale": min_scale,
                "num_frames": options["num_frames"],
                **additional_guider_kwargs,
            },
        }
    else:
        raise NotImplementedError
    return guider_config
def init_sampling(
    key=1,
    img2img_strength: Optional[float] = None,
    specify_num_samples: bool = True,
    stage2strength: Optional[float] = None,
    options: Optional[Dict[str, int]] = None,
):
    """Render the sampling controls and build the configured sampler.

    Args:
        key: Suffix that keeps the widgets of several samplers distinct.
        img2img_strength: When given, the sampler's discretization is wrapped
            in `Img2ImgDiscretizationWrapper` with this strength.
        specify_num_samples: Show the "num cols" input; otherwise one column.
        stage2strength: When given, the discretization is (additionally)
            wrapped in `Txt2NoisyDiscretizationWrapper`.
        options: Widget defaults ("num_steps", "sampler", "discretization" and
            the entries read by the discretization/guider helpers).

    Returns:
        Tuple `(sampler, num_rows, num_cols)`; `num_rows` is always 1.
    """
    options = {} if options is None else options

    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=num_cols, min_value=1, max_value=10
        )

    steps = st.number_input(
        f"steps #{key}", value=options.get("num_steps", 50), min_value=1, max_value=1000
    )
    sampler_name = st.sidebar.selectbox(
        f"Sampler #{key}",
        [
            "EulerEDMSampler",
            "HeunEDMSampler",
            "EulerAncestralSampler",
            "DPMPP2SAncestralSampler",
            "DPMPP2MSampler",
            "LinearMultistepSampler",
        ],
        options.get("sampler", 0),
    )
    discretization = st.sidebar.selectbox(
        f"Discretization #{key}",
        [
            "LegacyDDPMDiscretization",
            "EDMDiscretization",
        ],
        options.get("discretization", 0),
    )

    discretization_config = get_discretization(discretization, options=options, key=key)

    guider_config = get_guider(options=options, key=key)

    sampler = get_sampler(
        sampler_name, steps, discretization_config, guider_config, key=key
    )
    if img2img_strength is not None:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols


def get_discretization(discretization, options, key=1):
    """Build the config dict for the named discretization.

    "EDMDiscretization" renders sidebar inputs for `sigma_min`, `sigma_max`
    and `rho`, with defaults taken from `options`; the legacy DDPM
    discretization needs no parameters.

    Raises:
        ValueError: If `discretization` is not a known name (previously this
            surfaced as an `UnboundLocalError`).
    """
    if discretization == "LegacyDDPMDiscretization":
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }

    if discretization == "EDMDiscretization":
        sigma_min = st.sidebar.number_input(
            f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
        )  # 0.0292
        sigma_max = st.sidebar.number_input(
            f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
        )  # 14.6146
        rho = st.sidebar.number_input(f"rho #{key}", value=options.get("rho", 3.0))
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }

    raise ValueError(f"unknown discretization {discretization}!")
def load_img(
    display: bool = True,
    size: Union[None, int, Tuple[int, int]] = None,
    center_crop: bool = False,
    key=None,
):
    """Load an uploaded image and convert it to a batched tensor.

    Args:
        display: When true, echo the uploaded image in the Streamlit page.
        size: Optional target passed to ``transforms.Resize`` (and to
            ``transforms.CenterCrop`` when ``center_crop`` is set).
        center_crop: When true, append a center crop of ``size``.
        key: Optional Streamlit widget key forwarded to the file uploader.
            ``get_init_img`` already calls ``load_img(key=key)``; before this
            parameter existed that call raised ``TypeError``.

    Returns:
        ``None`` when no file has been uploaded; otherwise the transformed
        image with a leading batch axis, rescaled via ``2 * x - 1``.
    """
    # The module's keyed `get_interactive_image(key=None)` definition is the
    # one bound at call time, so forwarding `key` is safe.
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    transform = []
    if size is not None:
        transform.append(transforms.Resize(size))
    if center_crop:
        transform.append(transforms.CenterCrop(size))
    transform.append(transforms.ToTensor())
    # Shift the value range so it is centred on zero.
    transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))

    transform = transforms.Compose(transform)
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img
-> (b t) ...", t=T) additional_model_inputs = {} for k in batch2model_input: if k == "image_only_indicator": assert T is not None if isinstance( sampler.guider, ( VanillaCFG, LinearPredictionGuider, TrianglePredictionGuider, ), ): additional_model_inputs[k] = torch.zeros( num_samples[0] * 2, num_samples[1] ).to("cuda") else: additional_model_inputs[k] = torch.zeros(num_samples).to( "cuda" ) else: additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) randn = torch.randn(shape).to("cuda") def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) load_model(model.denoiser) load_model(model.model) samples_z = sampler(denoiser, randn, cond=c, uc=uc) unload_model(model.model) unload_model(model.denoiser) load_model(model.first_stage_model) model.en_and_decode_n_samples_a_time = ( decoding_t # Decode n frames at a time ) samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) unload_model(model.first_stage_model) if filter is not None: samples = filter(samples) if T is None: grid = torch.stack([samples]) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) else: as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T) for i, vid in enumerate(as_vids): grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c") st.image( grid.cpu().numpy(), f"Sample #{i} as image", ) if return_latents: return samples, samples_z return samples def get_batch( keys, value_dict: dict, N: Union[List, ListConfig], device: str = "cuda", T: int = None, additional_batch_uc_fields: List[str] = [], ): # Hardcoded demo setups; might undergo some changes in the future batch = {} batch_uc = {} for key in keys: if key == "txt": batch["txt"] = [value_dict["prompt"]] * math.prod(N) batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( 
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( torch.tensor( [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] ) .to(device) .repeat(math.prod(N), 1) ) elif key == "aesthetic_score": batch["aesthetic_score"] = ( torch.tensor([value_dict["aesthetic_score"]]) .to(device) .repeat(math.prod(N), 1) ) batch_uc["aesthetic_score"] = ( torch.tensor([value_dict["negative_aesthetic_score"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "target_size_as_tuple": batch["target_size_as_tuple"] = ( torch.tensor([value_dict["target_height"], value_dict["target_width"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "fps": batch[key] = ( torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) ) elif key == "fps_id": batch[key] = ( torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) ) elif key == "motion_bucket_id": batch[key] = ( torch.tensor([value_dict["motion_bucket_id"]]) .to(device) .repeat(math.prod(N)) ) elif key == "pool_image": batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to( device, dtype=torch.half ) elif key == "cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to("cuda"), "1 -> b", b=math.prod(N), ) elif key == "cond_frames": batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) elif key == "cond_frames_without_noise": batch[key] = repeat( value_dict["cond_frames_without_noise"], "1 ... 
@torch.no_grad()
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    """Run image-to-image sampling and display the result in Streamlit.

    Args:
        img: Input passed to ``model.encode_first_stage``; used directly as
            the latent when ``skip_encode`` is true.
        model: Diffusion engine providing ``conditioner``, ``denoiser``,
            ``model``, ``first_stage_model`` and ``ema_scope``.
        sampler: Sampler exposing ``discretization`` and ``num_steps``; called
            as ``sampler(denoiser, x, cond=..., uc=...)``.
        value_dict: Conditioning values consumed by ``get_batch``.
        num_samples: Number of samples; conditioning is truncated to it.
        force_uc_zero_embeddings: Forwarded to the conditioner.
        force_cond_zero_embeddings: Forwarded to the conditioner.
        additional_kwargs: Extra entries written into both the conditional
            and unconditional conditioning dicts. ``None`` means none.
        offset_noise_level: When positive, adds per-sample offset noise
            scaled by this factor.
        return_latents: When true, also return the sampled latents.
        skip_encode: When true, ``img`` is treated as an already-encoded
            latent.
        filter: Optional callable applied to the decoded samples.
        add_noise: When false, the latent is only rescaled, not noised.

    Returns:
        The decoded samples clamped to ``[0, 1]``, or ``(samples, latents)``
        when ``return_latents`` is true.
    """
    additional_kwargs = {} if additional_kwargs is None else additional_kwargs

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [num_samples],
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_model(model.conditioner)
                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))

                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]
                if skip_encode:
                    z = img
                else:
                    load_model(model.first_stage_model)
                    z = model.encode_first_stage(img)
                    unload_model(model.first_stage_model)

                noise = torch.randn_like(z)

                sigmas = sampler.discretization(sampler.num_steps).cuda()
                sigma = sigmas[0]

                st.info(f"all sigmas: {sigmas}")
                st.info(f"noising sigma: {sigma}")
                if offset_noise_level > 0.0:
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                if add_noise:
                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                    noised_z = noised_z / torch.sqrt(
                        1.0 + sigmas[0] ** 2.0
                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
                else:
                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                unload_model(model.first_stage_model)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                # `grid` used to be read before it was ever assigned, which
                # raised UnboundLocalError on every call. Build it from the
                # samples the same way `do_sample` does.
                grid = torch.stack([samples])
                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())
                if return_latents:
                    return samples, samples_z
                return samples
def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
    """Save each video as a PNG frame grid and an MP4, then show it.

    Args:
        video_batch: Frames laid out as ``(b t) c h w``; values are scaled by
            255 before the uint8 cast (presumably in ``[0, 1]`` - confirm
            against callers).
        save_path: Output directory; created if missing.
        T: Number of frames per video, used to split the leading axis.
        fps: Frame rate of the written MP4.

    Raises:
        subprocess.CalledProcessError: If the ffmpeg re-encode fails.
    """
    # Imported locally so the block does not depend on a module-level import.
    import subprocess

    os.makedirs(save_path, exist_ok=True)
    base_count = len(glob(os.path.join(save_path, "*.mp4")))

    video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
    video_batch = embed_watermark(video_batch)
    for vid in video_batch:
        save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)

        video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
        frames = (
            (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
        )
        imageio.mimwrite(video_path, frames, fps=fps)

        video_path_h264 = video_path[:-4] + "_h264.mp4"
        # Pass arguments as a list: the old shell string broke on paths that
        # contain quotes and allowed shell injection through `save_path`.
        # `-y` stops ffmpeg from blocking on an overwrite prompt if a stale
        # temporary file is left over from an interrupted run.
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path, "-c:v", "libx264", video_path_h264],
            check=True,
        )
        try:
            with open(video_path_h264, "rb") as f:
                video_bytes = f.read()
        finally:
            # The H.264 copy only exists to feed the player; always remove it.
            os.remove(video_path_h264)
        st.video(video_bytes)

        base_count += 1
def smooth_data(data, window_size):
    """Apply a moving average to ``data``, treating it as a closed loop.

    ``window_size`` samples from each end are wrapped onto the opposite end
    before convolving with a uniform kernel, and the padding is cut off again
    afterwards.

    Args:
        data: 1-D sequence supporting slicing and ``np.concatenate``.
        window_size: Length of the uniform kernel; also the wrap padding.

    Returns:
        The smoothed values aligned with the original samples.
    """
    # Wrap the tail in front and the head behind so the filter sees a cycle.
    wrapped = np.concatenate((data[-window_size:], data, data[:window_size]))

    box_filter = np.ones(window_size) / window_size
    filtered = np.convolve(wrapped, box_filter, mode="same")

    # `-0` is falsy, so zero padding keeps the full tail via `None`.
    stop = -window_size or None
    return filtered[window_size:stop]
def plot_3D(azim, polar, save_path, dynamic=True):
    """Plot a camera trajectory on the unit sphere and save it to disk.

    Args:
        azim: Azimuth angles in radians, one per camera position.
        polar: Polar angles in radians; converted to elevation as
            ``deg2rad(90) - polar``.
        save_path: Destination image path. Its parent directory is created
            when the path has one.
        dynamic: When true, draw arrows between consecutive positions;
            otherwise draw plain line segments.
    """
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    output_dir = os.path.dirname(save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    elev = np.deg2rad(90) - polar
    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(projection="3d")
        cm = plt.get_cmap("Greys")
        col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)]
        cm = plt.get_cmap("cool")
        col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))]

        # Spherical to Cartesian coordinates on the unit sphere.
        xs = np.cos(elev) * np.cos(azim)
        ys = np.cos(elev) * np.sin(azim)
        zs = np.sin(elev)

        ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0])
        xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1])
        for i in range(len(xs) - 1):
            if dynamic:
                ax.quiver(
                    xs[i],
                    ys[i],
                    zs[i],
                    xs_d[i],
                    ys_d[i],
                    zs_d[i],
                    lw=2,
                    color=col_line[i],
                )
            else:
                ax.plot(
                    xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i]
                )
            ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])

        # Ring the first and last positions so the endpoints stand out.
        ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
        ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
        ax.view_init(elev=30, azim=-20, roll=0)
        plt.savefig(save_path, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.clf()
        plt.close()
def get_resizing_factor(
    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
    """Return the scale factor that resizes ``current_shape`` toward ``desired_shape``.

    Each shape's ratio is element 1 divided by element 0. The ratios decide
    which pair of sides (short/short, long/short or long/long) is matched.

    Args:
        desired_shape: Target two-element shape.
        current_shape: Source two-element shape.

    Returns:
        The multiplicative resize factor as a float.
    """
    target_ratio = desired_shape[1] / desired_shape[0]
    source_ratio = current_shape[1] / current_shape[0]

    if target_ratio >= 1.0:
        # Target is at least as wide as tall (by this ratio convention).
        exceeds_target = source_ratio >= target_ratio
        flipped_orientation = source_ratio < 1.0
    else:
        exceeds_target = source_ratio <= target_ratio
        flipped_orientation = source_ratio > 1

    if exceeds_target:
        return min(desired_shape) / min(current_shape)
    if flipped_orientation:
        return max(desired_shape) / min(current_shape)
    return max(desired_shape) / max(current_shape)
def save_video(file_name, imgs, fps=10):
    """Write a sequence of image tensors to a video or GIF file.

    Args:
        file_name: Destination path. A ``.gif`` suffix produces a GIF written
            with ``loop=0``; anything else is written with imageio's default
            handling for that extension.
        imgs: Iterable of tensors; the first element of each is permuted
            from ``c h w`` to ``h w c`` and mapped via ``(x + 1) / 2 * 255``
            (presumably values in ``[-1, 1]`` - confirm against callers).
        fps: Frame rate passed to ``imageio.mimwrite``.
    """
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    output_dir = os.path.dirname(file_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    img_grid = [
        (((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8)
        for img in imgs
    ]
    if file_name.endswith(".gif"):
        imageio.mimwrite(file_name, img_grid, fps=fps, loop=0)
    else:
        imageio.mimwrite(file_name, img_grid, fps=fps)
= True else: raise ValueError("Path is not a valid video file.") elif path.is_dir(): all_img_paths = sorted( [ f for f in path.iterdir() if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] ] )[:n_frames] elif "*" in input_path: all_img_paths = sorted(glob(input_path))[:n_frames] else: raise ValueError if is_video_file and input_path.endswith(".gif"): images = read_gif(input_path, n_frames)[:n_frames] elif is_video_file and input_path.endswith(".mp4"): images = read_mp4(input_path, n_frames)[:n_frames] else: print(f"Loading {len(all_img_paths)} video frames...") images = [Image.open(img_path) for img_path in all_img_paths] if len(images) != n_frames: raise ValueError( f"Input video contains {len(images)} frames, fewer than {n_frames} frames." ) # Remove background for i, image in enumerate(images): if remove_bg: if image.mode == "RGBA": pass else: # image.thumbnail([W, H], Image.Resampling.LANCZOS) image = remove(image.convert("RGBA"), alpha_matting=True) images[i] = image # Crop video frames, assume the object is already in the center of the image white_thresh = 250 images_v0 = [] box_coord = [np.inf, np.inf, 0, 0] for image in images: image_arr = np.array(image) in_w, in_h = image_arr.shape[:2] original_center = (in_w // 2, in_h // 2) if image.mode == "RGBA": ret, mask = cv2.threshold( np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY ) else: # assume the input image has white background ret, mask = cv2.threshold( (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255, 0, 255, cv2.THRESH_BINARY, ) x, y, w, h = cv2.boundingRect(mask) box_coord[0] = min(box_coord[0], x) box_coord[1] = min(box_coord[1], y) box_coord[2] = max(box_coord[2], x + w) box_coord[3] = max(box_coord[3], y + h) box_square = max( original_center[0] - box_coord[0], original_center[1] - box_coord[1] ) box_square = max(box_square, box_coord[2] - original_center[0]) box_square = max(box_square, box_coord[3] - original_center[1]) x, y = max(0, original_center[0] - 
box_square), max( 0, original_center[1] - box_square ) w, h = min(image_arr.shape[0], 2 * box_square), min( image_arr.shape[1], 2 * box_square ) box_size = box_square * 2 for image in images: if image.mode == "RGB": image = image.convert("RGBA") image_arr = np.array(image) side_len = ( int(box_size / image_frame_ratio) if image_frame_ratio is not None else in_w ) padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) center = side_len // 2 box_size_w = min(w, box_size) box_size_h = min(h, box_size) padded_image[ center - box_size_w // 2 : center - box_size_w // 2 + box_size_w, center - box_size_h // 2 : center - box_size_h // 2 + box_size_h, ] = image_arr[x : x + w, y : y + h] rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS) # rgba = image.resize((W, H), Image.LANCZOS) rgba_arr = np.array(rgba) / 255.0 rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) image = (rgb * 255).astype(np.uint8) images_v0.append(image) processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4") imageio.mimwrite(processed_file, images_v0, fps=10) return processed_file def sample_sv3d( image, num_frames: Optional[int] = None, # 21 for SV3D num_steps: Optional[int] = None, version: str = "sv3d_u", fps_id: int = 6, motion_bucket_id: int = 127, cond_aug: float = 0.02, decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", polar_rad: Optional[Union[float, List[float]]] = None, azim_rad: Optional[List[float]] = None, verbose: Optional[bool] = False, sv3d_model=None, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. 
""" if sv3d_model is None: if version == "sv3d_u": model_config = "scripts/sampling/configs/sv3d_u.yaml" elif version == "sv3d_p": model_config = "scripts/sampling/configs/sv3d_p.yaml" else: raise ValueError(f"Version {version} does not exist.") model, filter = load_model( model_config, device, num_frames, num_steps, verbose, ) else: model = sv3d_model load_module_gpu(model) H, W = image.shape[2:] F = 8 C = 4 shape = (num_frames, C, H // F, W // F) value_dict = {} value_dict["cond_frames_without_noise"] = image value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) if "sv3d_p" in version: value_dict["polars_rad"] = polar_rad value_dict["azimuths_rad"] = azim_rad with torch.no_grad(): with torch.autocast(device): load_module_gpu(model.conditioner) batch, batch_uc = get_batch_sv3d( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1, num_frames], T=num_frames, device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) unload_module_gpu(model.conditioner) for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... 
def decode_latents(
    model, samples_z, img_matrix, frame_indices, view_indices, timesteps
):
    """Decode selected latents and store the images in ``img_matrix``.

    For every ``(t, v)`` pair from ``frame_indices`` x ``view_indices`` the
    latent ``samples_z[t, v]`` is decoded with the first-stage model. The
    result is clamped (via a round trip through ``[0, 1]``) and written to
    ``img_matrix[t][v]``. ``timesteps`` is only forwarded when the decoder is
    a ``VideoDecoder``.

    Returns:
        The same ``img_matrix`` object, updated in place.
    """
    load_module_gpu(model.first_stage_model)
    for t in frame_indices:
        for v in view_indices:
            # Only the video decoder accepts the extra `timesteps` argument.
            if isinstance(model.first_stage_model.decoder, VideoDecoder):
                decode_kwargs = {"timesteps": timesteps}
            else:
                decode_kwargs = {}
            decoded = model.decode_first_stage(samples_z[t, v][None], **decode_kwargs)
            unit_range = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
            img_matrix[t][v] = unit_range * 2 - 1
    unload_module_gpu(model.first_stage_model)
    return img_matrix
def get_discretization_no_st(discretization, options, key=1):
    """Build the discretization config dict without any Streamlit widgets.

    Args:
        discretization: Name of the discretization; either
            ``"LegacyDDPMDiscretization"`` or ``"EDMDiscretization"``.
        options: Mapping that may override ``sigma_min`` (default ``0.03``),
            ``sigma_max`` (default ``14.61``) and ``rho`` (default ``3.0``)
            for the EDM variant.
        key: Unused; kept for signature parity with the Streamlit variant.

    Returns:
        A dict with a ``"target"`` import path and, for the EDM variant, a
        ``"params"`` dict.

    Raises:
        ValueError: If ``discretization`` is not one of the supported names.
    """
    if discretization == "LegacyDDPMDiscretization":
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    elif discretization == "EDMDiscretization":
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": options.get("sigma_min", 0.03),
                "sigma_max": options.get("sigma_max", 14.61),
                "rho": options.get("rho", 3.0),
            },
        }
    else:
        # Without this branch an unknown name fell through to the return
        # statement and surfaced as an opaque UnboundLocalError.
        raise ValueError(f"unknown discretization {discretization}!")
    return discretization_config
"sgm.modules.diffusionmodules.guiders.IdentitySchedule", "params": {"scale": scale}, } elif scale_schedule == "Oscillating": small_scale = 4.0 large_scale = 16.0 sigma_cutoff = 1.0 scale_schedule_config = { "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule", "params": { "small_scale": small_scale, "large_scale": large_scale, "sigma_cutoff": sigma_cutoff, }, } else: raise NotImplementedError guider_config = { "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", "params": { "scale_schedule_config": scale_schedule_config, **additional_guider_kwargs, }, } elif guider == "LinearPredictionGuider": max_scale = options.get("cfg", 1.5) guider_config = { "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider", "params": { "max_scale": max_scale, "num_frames": options["num_frames"], **additional_guider_kwargs, }, } elif guider == "TrianglePredictionGuider": max_scale = options.get("cfg", 1.5) period = options.get("period", 1.0) period_fusing = options.get("period_fusing", "max") guider_config = { "target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider", "params": { "max_scale": max_scale, "num_frames": options["num_frames"], "period": period, "period_fusing": period_fusing, **additional_guider_kwargs, }, } elif guider == "TrapezoidPredictionGuider": max_scale = options.get("cfg", 1.5) edge_perc = options.get("edge_perc", 0.1) guider_config = { "target": "sgm.modules.diffusionmodules.guiders.TrapezoidPredictionGuider", "params": { "max_scale": max_scale, "num_frames": options["num_frames"], "edge_perc": edge_perc, **additional_guider_kwargs, }, } elif guider == "SpatiotemporalPredictionGuider": max_scale = options.get("cfg", 1.5) min_scale = options.get("min_cfg", 1.0) guider_config = { "target": "sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider", "params": { "max_scale": max_scale, "min_scale": min_scale, "num_frames": options["num_frames"], "num_views": options["num_views"], 
def get_sampler_no_st(sampler_name, steps, discretization_config, guider_config, key=1):
    """Instantiate the sampler class selected by ``sampler_name``.

    Every sampler receives ``num_steps=steps``, the two config dicts and
    ``verbose=False``; the remaining hyper-parameters are fixed constants
    chosen per sampler family below.

    Args:
        sampler_name: One of ``"EulerEDMSampler"``, ``"HeunEDMSampler"``,
            ``"EulerAncestralSampler"``, ``"DPMPP2SAncestralSampler"``,
            ``"DPMPP2MSampler"`` or ``"LinearMultistepSampler"``.
        steps: Number of sampling steps.
        discretization_config: Config dict forwarded to the sampler.
        guider_config: Config dict forwarded to the sampler.
        key: Unused; kept for call-site compatibility.

    Returns:
        The constructed sampler instance.

    Raises:
        ValueError: If ``sampler_name`` is not recognised.
    """
    # Arguments every sampler constructor is given.
    common = dict(
        num_steps=steps,
        discretization_config=discretization_config,
        guider_config=guider_config,
        verbose=False,
    )

    if sampler_name in ("EulerEDMSampler", "HeunEDMSampler"):
        churn = dict(s_churn=0.0, s_tmin=0.0, s_tmax=999.0, s_noise=1.0)
        if sampler_name == "EulerEDMSampler":
            return EulerEDMSampler(**common, **churn)
        return HeunEDMSampler(**common, **churn)

    if sampler_name in ("EulerAncestralSampler", "DPMPP2SAncestralSampler"):
        ancestral = dict(eta=1.0, s_noise=1.0)
        if sampler_name == "EulerAncestralSampler":
            return EulerAncestralSampler(**common, **ancestral)
        return DPMPP2SAncestralSampler(**common, **ancestral)

    if sampler_name == "DPMPP2MSampler":
        return DPMPP2MSampler(**common)

    if sampler_name == "LinearMultistepSampler":
        return LinearMultistepSampler(order=4, **common)

    raise ValueError(f"unknown sampler {sampler_name}!")
"EulerEDMSampler", "HeunEDMSampler", "EulerAncestralSampler", "DPMPP2SAncestralSampler", "DPMPP2MSampler", "LinearMultistepSampler", ][options.get("sampler", 0)] discretization = [ "LegacyDDPMDiscretization", "EDMDiscretization", ][options.get("discretization", 1)] discretization_config = get_discretization_no_st( discretization, options=options, key=key ) guider_config = get_guider_no_st(options=options, key=key) sampler = get_sampler_no_st( sampler, steps, discretization_config, guider_config, key=key ) return sampler, num_rows, num_cols def run_img2vid( version_dict, model, image, seed=23, polar_rad=[10] * 21, azim_rad=np.linspace(0, 360, 21 + 1)[1:], cond_motion=None, cond_view=None, decoding_t=None, cond_mv=True, ): options = version_dict["options"] H = version_dict["H"] W = version_dict["W"] T = version_dict["T"] C = version_dict["C"] F = version_dict["f"] init_dict = { "orig_width": 576, "orig_height": 576, "target_width": W, "target_height": H, } ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner)) value_dict = init_embedder_options_no_st( ukeys, init_dict, negative_prompt=options.get("negative_promt", ""), prompt="A 3D model.", ) if "fps" not in ukeys: value_dict["fps"] = 6 value_dict["is_image"] = 0 value_dict["is_webvid"] = 0 if cond_mv: value_dict["image_only_indicator"] = 1.0 else: value_dict["image_only_indicator"] = 0.0 cond_aug = 0.00 if cond_motion is not None: value_dict["cond_frames_without_noise"] = cond_motion value_dict["cond_frames"] = ( cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1) ) else: value_dict["cond_frames_without_noise"] = image value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) value_dict["cond_aug"] = cond_aug value_dict["polar_rad"] = polar_rad value_dict["azimuth_rad"] = azim_rad value_dict["rotated"] = False value_dict["cond_motion"] = cond_motion value_dict["cond_view"] = cond_view # seed_everything(seed) options["num_frames"] = T sampler, num_rows, num_cols = 
def prepare_inputs_forward_backward(
    img_matrix,
    view_indices,
    frame_indices,
    v0,
    t0,
    t1,
    model,
    version_dict,
    seed,
    polars,
    azims,
):
    """Prepare sampling inputs for a forward pass and a backward pass.

    The forward pass is anchored at time ``t0`` and visits ``frame_indices``
    in the given order; the backward pass is anchored at ``t1`` and visits
    them reversed. For each pass the anchor image is ``img_matrix[t][v0]``,
    the motion conditioning concatenates ``img_matrix[t][v0]`` over the
    visited frames, and the view conditioning concatenates
    ``img_matrix[anchor][v]`` over ``view_indices`` (both along dim 0).

    Returns:
        ``(forward_inputs, forward_frame_indices, backward_inputs,
        backward_frame_indices)`` where the ``*_inputs`` values are whatever
        ``prepare_sampling`` returns and the index sequences are copies.
    """

    def inputs_anchored_at(anchor_t, ordered_frames):
        # Conditioning tensors for one pass, handed to prepare_sampling.
        anchor_image = img_matrix[anchor_t][v0]
        motion = torch.cat([img_matrix[t][v0] for t in ordered_frames], 0)
        views = torch.cat([img_matrix[anchor_t][v] for v in view_indices], 0)
        return prepare_sampling(
            version_dict, model, anchor_image, seed, polars, azims, motion, views
        )

    # forward sampling
    forward_frame_indices = frame_indices.copy()
    forward_inputs = inputs_anchored_at(t0, forward_frame_indices)

    # backward sampling
    backward_frame_indices = frame_indices[::-1].copy()
    backward_inputs = inputs_anchored_at(t1, backward_frame_indices)

    return (
        forward_inputs,
        forward_frame_indices,
        backward_inputs,
        backward_frame_indices,
    )
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    batch2model_input: Optional[List] = None,
    return_latents=False,
    filter=None,
    T=None,
    additional_batch_uc_fields=None,
    decoding_t=None,
):
    """Run one complete sampling pass and return decoded samples.

    Steps, in order: build the conditional/unconditional batches with
    ``get_batch``, embed them through ``model.conditioner``, draw random
    latents of shape ``(prod(num_samples), C, H // F, W // F)``, run
    ``sampler`` over them, decode with ``model.decode_first_stage`` and map
    the result from ``[-1, 1]`` to ``[0, 1]`` with clamping.

    Args:
        model: Object exposing ``conditioner``, ``model``, ``denoiser``,
            ``first_stage_model``, ``ema_scope()`` and ``decode_first_stage``.
        sampler: Callable ``sampler(denoiser, randn, cond=..., uc=...)``; its
            ``guider`` attribute is inspected for ``image_only_indicator``.
        value_dict: Raw conditioning values; must contain
            ``"image_only_indicator"``.
        num_samples: Sample count; becomes ``[num_samples, T]`` when ``T`` is
            given, otherwise ``[num_samples]``.
        H, W, C, F: Used only to derive the latent shape (see above).
        force_uc_zero_embeddings: Forwarded to
            ``get_unconditional_conditioning``; ``None`` becomes ``[]``.
        force_cond_zero_embeddings: Forwarded as-is (``None`` is NOT replaced
            by an empty list here).
        batch2model_input: Keys whose values are passed to the denoiser as
            extra keyword arguments; ``None`` becomes ``[]``.
        return_latents: If true, also return the sampler output latents.
        filter: Optional callable applied to the decoded samples.
        T: Optional second batch dimension; also the fallback ``timesteps``
            when decoding with a ``VideoDecoder``.
        additional_batch_uc_fields: Forwarded to ``get_batch``.
        decoding_t: ``timesteps`` for ``VideoDecoder`` decoding; falls back
            to ``T`` when ``None``.

    Returns:
        ``samples``, or ``(samples, samples_z)`` when ``return_latents``.

    Note:
        The device is hard-coded to ``"cuda"`` throughout.
    """
    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
    batch2model_input = default(batch2model_input, [])
    additional_batch_uc_fields = default(additional_batch_uc_fields, [])

    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                if T is not None:
                    num_samples = [num_samples, T]
                else:
                    num_samples = [num_samples]

                # The conditioner is loaded only for the embedding step and
                # released right after (presumably to limit GPU memory use --
                # load/unload helpers are defined outside this block).
                load_module_gpu(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                    T=T,
                    additional_batch_uc_fields=additional_batch_uc_fields,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_module_gpu(model.conditioner)

                # Every entry except "crossattn" is cut to the first
                # prod(num_samples) rows and moved to the GPU.
                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                        )

                # An indicator of 0 blanks the view conditioning in place;
                # this requires a "cond_view" entry in both c and uc.
                if value_dict["image_only_indicator"] == 0:
                    c["cond_view"] *= 0
                    uc["cond_view"] *= 0

                additional_model_inputs = {}
                for k in batch2model_input:
                    if k == "image_only_indicator":
                        assert T is not None

                        if isinstance(
                            sampler.guider,
                            (
                                VanillaCFG,
                                LinearPredictionGuider,
                                TrianglePredictionGuider,
                                TrapezoidPredictionGuider,
                                SpatiotemporalPredictionGuider,
                            ),
                        ):
                            # First dimension is doubled for these guiders --
                            # presumably because they evaluate cond and uncond
                            # in one batch; TODO confirm against the guiders.
                            additional_model_inputs[k] = (
                                torch.zeros(num_samples[0] * 2, num_samples[1]).to(
                                    "cuda"
                                )
                                + value_dict["image_only_indicator"]
                            )
                        else:
                            additional_model_inputs[k] = torch.zeros(num_samples).to(
                                "cuda"
                            )
                    else:
                        additional_model_inputs[k] = batch[k]

                shape = (math.prod(num_samples), C, H // F, W // F)
                randn = torch.randn(shape).to("cuda")

                # Adapter giving the sampler a (input, sigma, c) callable
                # while injecting the extra model inputs collected above.
                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                load_module_gpu(model.model)
                load_module_gpu(model.denoiser)
                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                unload_module_gpu(model.denoiser)
                unload_module_gpu(model.model)

                load_module_gpu(model.first_stage_model)
                if isinstance(model.first_stage_model.decoder, VideoDecoder):
                    samples_x = model.decode_first_stage(
                        samples_z, timesteps=default(decoding_t, T)
                    )
                else:
                    samples_x = model.decode_first_stage(samples_z)
                # Map decoder output from [-1, 1] to [0, 1] and clamp.
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                unload_module_gpu(model.first_stage_model)

                if filter is not None:
                    samples = filter(samples)

                if return_latents:
                    return samples, samples_z
                return samples
def do_sample_per_step(
    model, sampler, noisy_latents, c, uc, step, additional_model_inputs
):
    """Advance the latents by a single sampler step.

    ``sampler.prepare_sampling_loop`` is re-run on a clone of
    ``noisy_latents`` to obtain the sigma schedule; its scaled latents are
    only used as the starting point when ``step == 0``, otherwise
    ``noisy_latents`` is used unchanged. The step then goes from
    ``sigmas[step]`` to ``sigmas[step + 1]``.

    Args:
        model: Object exposing ``model``, ``denoiser`` and ``ema_scope()``.
        sampler: Sampler exposing ``prepare_sampling_loop``, ``sampler_step``,
            ``num_steps``, ``s_churn``, ``s_tmin`` and ``s_tmax``.
        noisy_latents: Current latents.
        c, uc: Conditional and unconditional conditioning.
        step: Index of the step to execute.
        additional_model_inputs: Extra keyword arguments for the denoiser.

    Returns:
        The result of ``sampler.sampler_step`` for this step.
    """

    def denoiser(input, sigma, c):
        # Adapter exposing the (input, sigma, c) signature the sampler calls.
        return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)

    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        scaled_latents, s_in, sigmas, num_sigmas, _, _ = sampler.prepare_sampling_loop(
            noisy_latents.clone(), c, uc, sampler.num_steps
        )
        latents = scaled_latents if step == 0 else noisy_latents

        # Churn is applied only while the current sigma lies in [s_tmin, s_tmax].
        sigma_now = sigmas[step]
        if sampler.s_tmin <= sigma_now <= sampler.s_tmax:
            gamma = min(sampler.s_churn / (num_sigmas - 1), 2**0.5 - 1)
        else:
            gamma = 0.0

        load_module_gpu(model.model)
        load_module_gpu(model.denoiser)
        samples_z = sampler.sampler_step(
            s_in * sigma_now,
            s_in * sigmas[step + 1],
            denoiser,
            latents,
            c,
            uc,
            gamma,
        )
        unload_module_gpu(model.denoiser)
        unload_module_gpu(model.model)

    return samples_z
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    The result is a list built from a set, so duplicates are removed and the
    ordering is not meaningful.
    """
    unique_keys = {embedder.input_key for embedder in conditioner.embedders}
    return list(unique_keys)
"cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to(device), "1 -> b", b=math.prod(N), ) elif key == "cond_frames" or key == "cond_frames_without_noise": batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) elif key == "polars_rad" or key == "azimuths_rad": batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0]) else: batch[key] = value_dict[key] if T is not None: batch["num_video_frames"] = T for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc def get_batch( keys, value_dict: dict, N: Union[List, ListConfig], device: str = "cuda", T: int = None, additional_batch_uc_fields: List[str] = [], ): batch = {} batch_uc = {} for key in keys: if key == "txt": batch["txt"] = [value_dict["prompt"]] * math.prod(N) batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( torch.tensor( [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] ) .to(device) .repeat(math.prod(N), 1) ) elif key == "aesthetic_score": batch["aesthetic_score"] = ( torch.tensor([value_dict["aesthetic_score"]]) .to(device) .repeat(math.prod(N), 1) ) batch_uc["aesthetic_score"] = ( torch.tensor([value_dict["negative_aesthetic_score"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "target_size_as_tuple": batch["target_size_as_tuple"] = ( torch.tensor([value_dict["target_height"], value_dict["target_width"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "fps": batch[key] = ( torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) ) elif key == "fps_id": batch[key] = ( torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) ) elif key == "motion_bucket_id": batch[key] = 
( torch.tensor([value_dict["motion_bucket_id"]]) .to(device) .repeat(math.prod(N)) ) elif key == "pool_image": batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to( device, dtype=torch.half ) elif key == "is_image": batch[key] = ( torch.tensor([value_dict["is_image"]]) .to(device) .repeat(math.prod(N)) .long() ) elif key == "is_webvid": batch[key] = ( torch.tensor([value_dict["is_webvid"]]) .to(device) .repeat(math.prod(N)) .long() ) elif key == "cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to("cuda"), "1 -> b", b=math.prod(N), ) elif ( key == "cond_frames" or key == "cond_frames_without_noise" or key == "back_frames" ): # batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0]) batch[key] = value_dict[key] elif key == "interpolation_context": batch[key] = repeat( value_dict["interpolation_context"], "b ... -> (b n) ...", n=N[1] ) elif key == "start_frame": assert T is not None batch[key] = repeat(value_dict[key], "b ... -> (b t) ...", t=T) elif key == "polar_rad" or key == "azimuth_rad": batch[key] = ( torch.tensor(value_dict[key]).to(device).repeat(math.prod(N) // T) ) elif key == "rotated": batch[key] = ( torch.tensor([value_dict["rotated"]]).to(device).repeat(math.prod(N)) ) else: batch[key] = value_dict[key] if T is not None: batch["num_video_frames"] = T for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) elif key in additional_batch_uc_fields and key not in batch_uc: batch_uc[key] = copy.copy(batch[key]) return batch, batch_uc def load_model( config: str, device: str, num_frames: int, num_steps: int, verbose: bool = False, ckpt_path: str = None, ): config = OmegaConf.load(config) if device == "cuda": config.model.params.conditioner_config.params.emb_models[ 0 ].params.open_clip_embedding_config.params.init_device = device config.model.params.sampler_config.params.verbose = verbose 
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Collect the raw conditioning values needed for the given embedder keys.

    Args:
        keys: Iterable of embedder input keys; unrecognised keys are ignored.
        init_dict: Mapping providing ``"orig_width"``/``"orig_height"`` (read
            for ``"original_size_as_tuple"``) and
            ``"target_width"``/``"target_height"`` (read for
            ``"target_size_as_tuple"``).
        prompt: Text stored under ``"prompt"`` when ``"txt"`` is requested.
        negative_prompt: Text stored under ``"negative_prompt"`` when
            ``"txt"`` is requested; ``None`` yields the empty string.
            (The argument was previously accepted but silently discarded.)

    Returns:
        A dict of conditioning values keyed as the batch builder expects.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = (
                "" if negative_prompt is None else negative_prompt
            )

        if key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]

        if key == "crop_coords_top_left":
            # No cropping: the crop origin is fixed at the top-left corner.
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict
# Per-model settings for the video demo, keyed by the name shown in the
# "Model Version" select box. T/H/W seed the sidebar defaults, C and f are
# read directly, and "options" is handed to init_sampling (which also gets a
# "num_frames" entry written into it at run time, so every model keeps its
# own separate options dict).
VERSION2SPECS = dict(
    svd=dict(
        T=14,
        H=576,
        W=1024,
        C=4,
        f=8,
        config="configs/inference/svd.yaml",
        ckpt="checkpoints/svd.safetensors",
        options=dict(
            discretization=1,
            cfg=2.5,
            sigma_min=0.002,
            sigma_max=700.0,
            rho=7.0,
            guider=2,
            force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
            num_steps=25,
        ),
    ),
    svd_image_decoder=dict(
        T=14,
        H=576,
        W=1024,
        C=4,
        f=8,
        config="configs/inference/svd_image_decoder.yaml",
        ckpt="checkpoints/svd_image_decoder.safetensors",
        options=dict(
            discretization=1,
            cfg=2.5,
            sigma_min=0.002,
            sigma_max=700.0,
            rho=7.0,
            guider=2,
            force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
            num_steps=25,
        ),
    ),
    svd_xt=dict(
        T=25,
        H=576,
        W=1024,
        C=4,
        f=8,
        config="configs/inference/svd.yaml",
        ckpt="checkpoints/svd_xt.safetensors",
        options=dict(
            discretization=1,
            cfg=3.0,
            min_cfg=1.5,
            sigma_min=0.002,
            sigma_max=700.0,
            rho=7.0,
            guider=2,
            force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
            num_steps=30,
            decoding_t=14,
        ),
    ),
    svd_xt_image_decoder=dict(
        T=25,
        H=576,
        W=1024,
        C=4,
        f=8,
        config="configs/inference/svd_image_decoder.yaml",
        ckpt="checkpoints/svd_xt_image_decoder.safetensors",
        options=dict(
            discretization=1,
            cfg=3.0,
            min_cfg=1.5,
            sigma_min=0.002,
            sigma_max=700.0,
            rho=7.0,
            guider=2,
            force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
            num_steps=30,
            decoding_t=14,
        ),
    ),
    sv3d_u=dict(
        T=21,
        H=576,
        W=576,
        C=4,
        f=8,
        config="configs/inference/sv3d_u.yaml",
        ckpt="checkpoints/sv3d_u.safetensors",
        options=dict(
            discretization=1,
            cfg=2.5,
            sigma_min=0.002,
            sigma_max=700.0,
            rho=7.0,
            guider=3,
            force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
            num_steps=50,
            decoding_t=14,
        ),
    ),
    sv3d_p=dict(
        T=21,
        H=576,
        W=576,
        C=4,
        f=8,
        config="configs/inference/sv3d_p.yaml",
        ckpt="checkpoints/sv3d_p.safetensors",
        options=dict(
            discretization=1,
            cfg=2.5,
            sigma_min=0.002,
            sigma_max=700.0,
            rho=7.0,
            guider=3,
            force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
            num_steps=50,
            decoding_t=14,
        ),
    ),
)
"sv3d_u" in version: elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90) value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T) value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:] seed = st.sidebar.number_input( "seed", value=23, min_value=0, max_value=int(1e9) ) seed_everything(seed) save_locally, save_path = init_save_locally( os.path.join(SAVE_PATH, version), init_value=True ) if "sv3d" in version: plot_save_path = os.path.join(save_path, "plot_3D.png") plot_3D( azim=value_dict["azimuths_rad"], polar=value_dict["polars_rad"], save_path=plot_save_path, dynamic=("sv3d_p" in version), ) st.image( plot_save_path, f"3D camera trajectory", ) options["num_frames"] = T sampler, num_rows, num_cols = init_sampling(options=options) num_samples = num_rows * num_cols decoding_t = st.number_input( "Decode t frames at a time (set small if you are low on VRAM)", value=options.get("decoding_t", T), min_value=1, max_value=int(1e9), ) if st.checkbox("Overwrite fps in mp4 generator", False): saving_fps = st.number_input( f"saving video at fps:", value=value_dict["fps"], min_value=1 ) else: saving_fps = value_dict["fps"] if st.button("Sample"): out = do_sample( model, sampler, value_dict, num_samples, H, W, C, F, T=T, batch2model_input=["num_video_frames", "image_only_indicator"], force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None), force_cond_zero_embeddings=options.get( "force_cond_zero_embeddings", None ), return_latents=False, decoding_t=decoding_t, ) if isinstance(out, (tuple, list)): samples, samples_z = out else: samples = out samples_z = None if save_locally: save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps) ================================================ FILE: scripts/sampling/configs/sv3d_p.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: 
checkpoints/sv3d_p.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 1280 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - input_key: cond_frames_without_noise is_trainable: False target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: polars_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: azimuths_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 
params: outdim: 512 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 attn_resolutions: [ ] dropout: 0.0 sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider params: max_scale: 2.5 ================================================ FILE: scripts/sampling/configs/sv3d_u.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: checkpoints/sv3d_u.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 256 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: 
sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [ 1, 2, 4, 4 ] num_res_blocks: 2 attn_resolutions: [ ] dropout: 0.0 sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider params: max_scale: 2.5 ================================================ FILE: scripts/sampling/configs/sv4d.yaml ================================================ N_TIME: 5 N_VIEW: 8 N_FRAMES: 40 model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 en_and_decode_n_samples_a_time: 7 
disable_first_stage_autocast: True ckpt_path: checkpoints/sv4d.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime params: adm_in_channels: 1280 attention_resolutions: [4, 2, 1] channel_mult: [1, 2, 4, 4] context_dim: 1024 motion_context_dim: 4 extra_ff_mix_layer: True in_channels: 8 legacy: False model_channels: 320 num_classes: sequential num_head_channels: 64 num_res_blocks: 2 out_channels: 4 replicate_time_mix_bug: True spatial_transformer_attn_type: softmax-xformers time_block_merge_factor: 0.0 time_block_merge_strategy: learned_with_images time_kernel_size: [3, 1, 1] time_mix_legacy: False transformer_depth: 1 use_checkpoint: False use_linear_in_transformer: True use_spatial_context: True use_spatial_transformer: True use_motion_attention: True conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder is_trainable: False params: n_cond_frames: ${N_TIME} n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder is_trainable: False params: is_ae: True n_cond_frames: ${N_FRAMES} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 embed_dim: 4 lossconfig: target: torch.nn.Identity monitor: val/rec_loss sigma_cond_config: target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 sigma_sampler_config: 
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler - input_key: polar_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: azimuth_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: cond_view is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 lossconfig: target: torch.nn.Identity is_ae: True n_cond_frames: ${N_VIEW} n_copies: 1 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler - input_key: cond_motion is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: is_ae: True n_cond_frames: ${N_TIME} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 lossconfig: target: torch.nn.Identity sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 
resolution: 256 z_channels: 4 sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 500.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: max_scale: 2.5 num_frames: ${N_FRAMES} additional_cond_keys: [ cond_view, cond_motion ] ================================================ FILE: scripts/sampling/configs/sv4d2.yaml ================================================ N_TIME: 12 N_VIEW: 4 N_FRAMES: 48 model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 en_and_decode_n_samples_a_time: 8 disable_first_stage_autocast: True ckpt_path: checkpoints/sv4d2.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime params: adm_in_channels: 1280 attention_resolutions: [4, 2, 1] channel_mult: [1, 2, 4, 4] context_dim: 1024 motion_context_dim: 4 extra_ff_mix_layer: True in_channels: 8 legacy: False model_channels: 320 num_classes: sequential num_head_channels: 64 num_res_blocks: 2 out_channels: 4 replicate_time_mix_bug: True spatial_transformer_attn_type: softmax-xformers time_block_merge_factor: 0.0 time_block_merge_strategy: learned_with_images time_kernel_size: [3, 1, 1] time_mix_legacy: False transformer_depth: 1 use_checkpoint: False use_linear_in_transformer: True use_spatial_context: True use_spatial_transformer: True separate_motion_merge_factor: True use_motion_attention: True use_3d_attention: True use_camera_emb: True conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder is_trainable: False params: n_cond_frames: 
${N_TIME} n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder is_trainable: False params: is_ae: True n_cond_frames: ${N_FRAMES} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 embed_dim: 4 lossconfig: target: torch.nn.Identity monitor: val/rec_loss sigma_cond_config: target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler - input_key: polar_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: azimuth_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: cond_view is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: is_ae: True n_cond_frames: ${N_VIEW} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 lossconfig: target: torch.nn.Identity sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler - input_key: cond_motion is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: is_ae: True n_cond_frames: ${N_TIME} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: 
attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 lossconfig: target: torch.nn.Identity sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 500.0 guider_config: target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider params: max_scale: 1.5 min_scale: 1.5 num_frames: ${N_FRAMES} num_views: ${N_VIEW} additional_cond_keys: [ cond_view, cond_motion ] ================================================ FILE: scripts/sampling/configs/sv4d2_8views.yaml ================================================ N_TIME: 5 N_VIEW: 8 N_FRAMES: 40 model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 en_and_decode_n_samples_a_time: 8 disable_first_stage_autocast: True ckpt_path: checkpoints/sv4d2_8views.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime params: adm_in_channels: 1280 attention_resolutions: [4, 2, 1] channel_mult: [1, 2, 4, 4] context_dim: 1024 
motion_context_dim: 4 extra_ff_mix_layer: True in_channels: 8 legacy: False model_channels: 320 num_classes: sequential num_head_channels: 64 num_res_blocks: 2 out_channels: 4 replicate_time_mix_bug: True spatial_transformer_attn_type: softmax-xformers time_block_merge_factor: 0.0 time_block_merge_strategy: learned_with_images time_kernel_size: [3, 1, 1] time_mix_legacy: False transformer_depth: 1 use_checkpoint: False use_linear_in_transformer: True use_spatial_context: True use_spatial_transformer: True separate_motion_merge_factor: True use_motion_attention: True use_3d_attention: False use_camera_emb: True conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder is_trainable: False params: n_cond_frames: ${N_TIME} n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: cond_frames target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder is_trainable: False params: is_ae: True n_cond_frames: ${N_FRAMES} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 embed_dim: 4 lossconfig: target: torch.nn.Identity monitor: val/rec_loss sigma_cond_config: target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler - input_key: polar_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: azimuth_rad is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 512 - input_key: cond_view is_trainable: False target: 
sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: is_ae: True n_cond_frames: ${N_VIEW} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 lossconfig: target: torch.nn.Identity sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler - input_key: cond_motion is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: is_ae: True n_cond_frames: ${N_TIME} n_copies: 1 encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 lossconfig: target: torch.nn.Identity sigma_sampler_config: target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: torch.nn.Identity decoder_config: target: sgm.modules.diffusionmodules.model.Decoder params: attn_resolutions: [] attn_type: vanilla-xformers ch: 128 ch_mult: [1, 2, 4, 4] double_z: True dropout: 0.0 in_channels: 3 num_res_blocks: 2 out_ch: 3 resolution: 256 z_channels: 4 sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: num_steps: 50 discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 500.0 guider_config: target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider params: max_scale: 2.0 min_scale: 1.5 
num_frames: ${N_FRAMES} num_views: ${N_VIEW} additional_cond_keys: [ cond_view, cond_motion ] ================================================ FILE: scripts/sampling/configs/svd.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: checkpoints/svd.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 
monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.autoencoding.temporal_ae.VideoDecoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 video_kernel_size: [3, 1, 1] sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: max_scale: 2.5 min_scale: 1.0 ================================================ FILE: scripts/sampling/configs/svd_image_decoder.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: checkpoints/svd_image_decoder.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: 
sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 
out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: max_scale: 2.5 min_scale: 1.0 ================================================ FILE: scripts/sampling/configs/svd_xt.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: checkpoints/svd_xt.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False 
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.autoencoding.temporal_ae.VideoDecoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 video_kernel_size: [3, 1, 1] sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: max_scale: 3.0 min_scale: 1.5 ================================================ FILE: scripts/sampling/configs/svd_xt_1_1.yaml ================================================ model: target: 
sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: checkpoints/svd_xt_1_1.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: 
torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencodingEngine params: loss_config: target: torch.nn.Identity regularizer_config: target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer encoder_config: target: sgm.modules.diffusionmodules.model.Encoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 decoder_config: target: sgm.modules.autoencoding.temporal_ae.VideoDecoder params: attn_type: vanilla double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 video_kernel_size: [3, 1, 1] sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: max_scale: 3.0 min_scale: 1.5 ================================================ FILE: scripts/sampling/configs/svd_xt_image_decoder.yaml ================================================ model: target: sgm.models.diffusion.DiffusionEngine params: scale_factor: 0.18215 disable_first_stage_autocast: True ckpt_path: checkpoints/svd_xt_image_decoder.safetensors denoiser_config: target: sgm.modules.diffusionmodules.denoiser.Denoiser params: scaling_config: target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise network_config: target: sgm.modules.diffusionmodules.video_model.VideoUNet params: adm_in_channels: 768 num_classes: sequential use_checkpoint: True in_channels: 8 out_channels: 4 model_channels: 320 attention_resolutions: [4, 2, 1] num_res_blocks: 2 channel_mult: [1, 2, 4, 4] num_head_channels: 
64 use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 spatial_transformer_attn_type: softmax-xformers extra_ff_mix_layer: True use_spatial_context: True merge_strategy: learned_with_images video_kernel_size: [3, 1, 1] conditioner_config: target: sgm.modules.GeneralConditioner params: emb_models: - is_trainable: False input_key: cond_frames_without_noise target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder params: n_cond_frames: 1 n_copies: 1 open_clip_embedding_config: target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder params: freeze: True - input_key: fps_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: motion_bucket_id is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 - input_key: cond_frames is_trainable: False target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder params: disable_encoder_autocast: True n_cond_frames: 1 n_copies: 1 is_ae: True encoder_config: target: sgm.models.autoencoder.AutoencoderKLModeOnly params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity - input_key: cond_aug is_trainable: False target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND params: outdim: 256 first_stage_config: target: sgm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss ddconfig: attn_type: vanilla-xformers double_z: True z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: [1, 2, 4, 4] num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity sampler_config: target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler params: discretization_config: target: 
sgm.modules.diffusionmodules.discretizer.EDMDiscretization params: sigma_max: 700.0 guider_config: target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider params: max_scale: 3.0 min_scale: 1.5 ================================================ FILE: scripts/sampling/simple_video_sample.py ================================================ import math import os import sys from glob import glob from pathlib import Path from typing import List, Optional sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) import cv2 import imageio import numpy as np import torch from einops import rearrange, repeat from fire import Fire from omegaconf import OmegaConf from PIL import Image from rembg import remove from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark from sgm.util import default, instantiate_from_config from torchvision.transforms import ToTensor def sample( input_path: str = "assets/test_image.png", # Can either be image file or folder with image files num_frames: Optional[int] = None, # 21 for SV3D num_steps: Optional[int] = None, version: str = "svd", fps_id: int = 6, motion_bucket_id: int = 127, cond_aug: float = 0.02, seed: int = 23, decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", output_folder: Optional[str] = None, elevations_deg: Optional[float | List[float]] = 10.0, # For SV3D azimuths_deg: Optional[List[float]] = None, # For SV3D image_frame_ratio: Optional[float] = None, verbose: Optional[bool] = False, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. 
""" if version == "svd": num_frames = default(num_frames, 14) num_steps = default(num_steps, 25) output_folder = default(output_folder, "outputs/simple_video_sample/svd/") model_config = "scripts/sampling/configs/svd.yaml" elif version == "svd_xt": num_frames = default(num_frames, 25) num_steps = default(num_steps, 30) output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/") model_config = "scripts/sampling/configs/svd_xt.yaml" elif version == "svd_image_decoder": num_frames = default(num_frames, 14) num_steps = default(num_steps, 25) output_folder = default( output_folder, "outputs/simple_video_sample/svd_image_decoder/" ) model_config = "scripts/sampling/configs/svd_image_decoder.yaml" elif version == "svd_xt_image_decoder": num_frames = default(num_frames, 25) num_steps = default(num_steps, 30) output_folder = default( output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/" ) model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml" elif version == "sv3d_u": num_frames = 21 num_steps = default(num_steps, 50) output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/") model_config = "scripts/sampling/configs/sv3d_u.yaml" cond_aug = 1e-5 elif version == "sv3d_p": num_frames = 21 num_steps = default(num_steps, 50) output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/") model_config = "scripts/sampling/configs/sv3d_p.yaml" cond_aug = 1e-5 if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): elevations_deg = [elevations_deg] * num_frames assert ( len(elevations_deg) == num_frames ), f"Please provide 1 value, or a list of {num_frames} values for elevations_deg! Given {len(elevations_deg)}" polars_rad = [np.deg2rad(90 - e) for e in elevations_deg] if azimuths_deg is None: azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360 assert ( len(azimuths_deg) == num_frames ), f"Please provide a list of {num_frames} values for azimuths_deg! 
Given {len(azimuths_deg)}" azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] azimuths_rad[:-1].sort() else: raise ValueError(f"Version {version} does not exist.") model, filter = load_model( model_config, device, num_frames, num_steps, verbose, ) torch.manual_seed(seed) path = Path(input_path) all_img_paths = [] if path.is_file(): if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): all_img_paths = [input_path] else: raise ValueError("Path is not valid image file.") elif path.is_dir(): all_img_paths = sorted( [ f for f in path.iterdir() if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] ] ) if len(all_img_paths) == 0: raise ValueError("Folder does not contain any images.") else: raise ValueError for input_img_path in all_img_paths: if "sv3d" in version: image = Image.open(input_img_path) if image.mode == "RGBA": pass else: # remove bg image.thumbnail([768, 768], Image.Resampling.LANCZOS) image = remove(image.convert("RGBA"), alpha_matting=True) # resize object in frame image_arr = np.array(image) in_w, in_h = image_arr.shape[:2] ret, mask = cv2.threshold( np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY ) x, y, w, h = cv2.boundingRect(mask) max_size = max(w, h) side_len = ( int(max_size / image_frame_ratio) if image_frame_ratio is not None else in_w ) padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) center = side_len // 2 padded_image[ center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w, ] = image_arr[y : y + h, x : x + w] # resize frame to 576x576 rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS) # white bg rgba_arr = np.array(rgba) / 255.0 rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) input_image = Image.fromarray((rgb * 255).astype(np.uint8)) else: with Image.open(input_img_path) as image: if image.mode == "RGBA": image = image.convert("RGB") w, h = image.size if h % 64 != 0 or w % 64 != 0: width, height = 
map(lambda x: x - x % 64, (w, h)) input_image = input_image.resize((width, height)) print( f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" ) input_image = np.array(image) image = ToTensor()(input_image) image = image * 2.0 - 1.0 image = image.unsqueeze(0).to(device) H, W = image.shape[2:] assert image.shape[1] == 3 F = 8 C = 4 shape = (num_frames, C, H // F, W // F) if (H, W) != (576, 1024) and "sv3d" not in version: print( "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`." ) if (H, W) != (576, 576) and "sv3d" in version: print( "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576." ) if motion_bucket_id > 255: print( "WARNING: High motion bucket! This may lead to suboptimal performance." ) if fps_id < 5: print("WARNING: Small fps value! This may lead to suboptimal performance.") if fps_id > 30: print("WARNING: Large fps value! This may lead to suboptimal performance.") value_dict = {} value_dict["cond_frames_without_noise"] = image value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) if "sv3d_p" in version: value_dict["polars_rad"] = polars_rad value_dict["azimuths_rad"] = azimuths_rad with torch.no_grad(): with torch.autocast(device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1, num_frames], T=num_frames, device=device, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... 
-> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) randn = torch.randn(shape, device=device) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros( 2, num_frames ).to(device) additional_model_inputs["num_video_frames"] = batch["num_video_frames"] def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) if "sv3d" in version: samples_x[-1:] = value_dict["cond_frames_without_noise"] samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) os.makedirs(output_folder, exist_ok=True) base_count = len(glob(os.path.join(output_folder, "*.mp4"))) imageio.imwrite( os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image ) samples = embed_watermark(samples) samples = filter(samples) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255) .cpu() .numpy() .astype(np.uint8) ) video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") imageio.mimwrite(video_path, vid) def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) def get_batch(keys, value_dict, N, T, device): batch = {} batch_uc = {} for key in keys: if key == "fps_id": batch[key] = ( torch.tensor([value_dict["fps_id"]]) .to(device) .repeat(int(math.prod(N))) ) elif key == "motion_bucket_id": batch[key] = ( torch.tensor([value_dict["motion_bucket_id"]]) .to(device) .repeat(int(math.prod(N))) ) elif key == "cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to(device), "1 -> b", b=math.prod(N), ) elif key == "cond_frames" or key == "cond_frames_without_noise": batch[key] = repeat(value_dict[key], "1 ... 
-> b ...", b=N[0]) elif key == "polars_rad" or key == "azimuths_rad": batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0]) else: batch[key] = value_dict[key] if T is not None: batch["num_video_frames"] = T for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc def load_model( config: str, device: str, num_frames: int, num_steps: int, verbose: bool = False, ): config = OmegaConf.load(config) if device == "cuda": config.model.params.conditioner_config.params.emb_models[ 0 ].params.open_clip_embedding_config.params.init_device = device config.model.params.sampler_config.params.verbose = verbose config.model.params.sampler_config.params.num_steps = num_steps config.model.params.sampler_config.params.guider_config.params.num_frames = ( num_frames ) if device == "cuda": with torch.device(device): model = instantiate_from_config(config.model).to(device).eval() else: model = instantiate_from_config(config.model).to(device).eval() filter = DeepFloydDataFiltering(verbose=False, device=device) return model, filter if __name__ == "__main__": Fire(sample) ================================================ FILE: scripts/sampling/simple_video_sample_4d.py ================================================ import os import sys from glob import glob from typing import List, Optional, Union from tqdm import tqdm sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) import numpy as np import torch from fire import Fire from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder from scripts.demo.sv4d_helpers import ( decode_latents, load_model, initial_model_load, read_video, run_img2vid, prepare_sampling, prepare_inputs, do_sample_per_step, sample_sv3d, save_video, preprocess_video, ) def sample( input_path: str = "assets/sv4d_videos/test_video1.mp4", # Can either be image file or folder with image files output_folder: 
Optional[str] = "outputs/sv4d", num_steps: Optional[int] = 20, sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p img_size: int = 576, # image resolution fps_id: int = 6, motion_bucket_id: int = 127, cond_aug: float = 1e-5, seed: int = 23, encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary. decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", elevations_deg: Optional[Union[float, List[float]]] = 10.0, azimuths_deg: Optional[List[float]] = None, image_frame_ratio: Optional[float] = 0.917, verbose: Optional[bool] = False, remove_bg: bool = False, ): """ Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`. """ # Set model config T = 5 # number of frames per sample V = 8 # number of views per sample F = 8 # vae factor to downsize image->latent C = 4 H, W = img_size, img_size n_frames = 21 # number of input and output video frames n_views = V + 1 # number of output video views (1 input view + 8 novel views) n_views_sv3d = 21 subsampled_views = np.array( [0, 2, 5, 7, 9, 12, 14, 16, 19] ) # subsample (V+1=)9 (uniform) views from 21 SV3D views model_config = "scripts/sampling/configs/sv4d.yaml" version_dict = { "T": T * V, "H": H, "W": W, "C": C, "f": F, "options": { "discretization": 1, "cfg": 2.0, "num_views": V, "sigma_min": 0.002, "sigma_max": 700.0, "rho": 7.0, "guider": 5, "num_steps": num_steps, "force_uc_zero_embeddings": [ "cond_frames", "cond_frames_without_noise", "cond_view", "cond_motion", ], "additional_guider_kwargs": { "additional_cond_keys": ["cond_view", "cond_motion"] }, }, } torch.manual_seed(seed) os.makedirs(output_folder, exist_ok=True) # Read input video frames i.e. 
images at view 0 print(f"Reading {input_path}") base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11 processed_input_path = preprocess_video( input_path, remove_bg=remove_bg, n_frames=n_frames, W=W, H=H, output_folder=output_folder, image_frame_ratio=image_frame_ratio, base_count=base_count, ) images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device) # Get camera viewpoints if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): elevations_deg = [elevations_deg] * n_views_sv3d assert ( len(elevations_deg) == n_views_sv3d ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}" if azimuths_deg is None: azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360 assert ( len(azimuths_deg) == n_views_sv3d ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}" polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg]) azimuths_rad = np.array( [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] ) # Sample multi-view images of the first frame using SV3D i.e. 
images at time 0 images_t0 = sample_sv3d( images_v0[0], n_views_sv3d, num_steps, sv3d_version, fps_id, motion_bucket_id, cond_aug, decoding_t, device, polars_rad, azimuths_rad, verbose, ) images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame # Initialize image matrix img_matrix = [[None] * n_views for _ in range(n_frames)] for i, v in enumerate(subsampled_views): img_matrix[0][i] = images_t0[v].unsqueeze(0) for t in range(n_frames): img_matrix[t][0] = images_v0[t] save_video( os.path.join(output_folder, f"{base_count:06d}_t000.mp4"), img_matrix[0], ) # save_video( # os.path.join(output_folder, f"{base_count:06d}_v000.mp4"), # [img_matrix[t][0] for t in range(n_frames)], # ) # Load SV4D model model, filter = load_model( model_config, device, version_dict["T"], num_steps, verbose, ) model = initial_model_load(model) for emb in model.conditioner.embedders: if isinstance(emb, VideoPredictionEmbedderWithEncoder): emb.en_and_decode_n_samples_a_time = encoding_t model.en_and_decode_n_samples_a_time = decoding_t # Interleaved sampling for anchor frames t0, v0 = 0, 0 frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20] view_indices = np.arange(V) + 1 print(f"Sampling anchor frames {frame_indices}") image = img_matrix[t0][v0] cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) samples = run_img2vid( version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t ) samples = samples.view(T, V, 3, H, W) for i, t in enumerate(frame_indices): for j, v in enumerate(view_indices): if img_matrix[t][v] is None: img_matrix[t][v] = samples[i, j][None] * 2 - 1 # Dense sampling for the rest print(f"Sampling dense frames:") for t0 in 
tqdm(np.arange(0, n_frames - 1, T - 1)): # [0, 4, 8, 12, 16] frame_indices = t0 + np.arange(T) print(f"Sampling dense frames {frame_indices}") latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda") polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) # alternate between forward and backward conditioning forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs( frame_indices, img_matrix, v0, view_indices, model, version_dict, seed, polars, azims ) for step in tqdm(range(num_steps)): if step % 2 == 1: c, uc, additional_model_inputs, sampler = forward_inputs frame_indices = forward_frame_indices else: c, uc, additional_model_inputs, sampler = backward_inputs frame_indices = backward_frame_indices noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1) samples = do_sample_per_step( model, sampler, noisy_latents, c, uc, step, additional_model_inputs, ) samples = samples.view(T, V, C, H // F, W // F) for i, t in enumerate(frame_indices): for j, v in enumerate(view_indices): latent_matrix[t, v] = samples[i, j] img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T) # Save output videos for v in view_indices: vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4") print(f"Saving {vid_file}") save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)]) # Save diagonal video diag_frames = [ img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames) ] vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4") print(f"Saving {vid_file}") save_video(vid_file, diag_frames) if __name__ == "__main__": Fire(sample) ================================================ FILE: scripts/sampling/simple_video_sample_4d2.py ================================================ import os 
import sys from glob import glob from typing import List, Optional from tqdm import tqdm sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) import numpy as np import torch from fire import Fire from scripts.demo.sv4d_helpers import ( load_model, preprocess_video, read_video, run_img2vid, save_video, ) from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder sv4d2_configs = { "sv4d2": { "T": 12, # number of frames per sample "V": 4, # number of views per sample "model_config": "scripts/sampling/configs/sv4d2.yaml", "version_dict": { "T": 12 * 4, "options": { "discretization": 1, "cfg": 2.0, "min_cfg": 2.0, "num_views": 4, "sigma_min": 0.002, "sigma_max": 700.0, "rho": 7.0, "guider": 2, "force_uc_zero_embeddings": [ "cond_frames", "cond_frames_without_noise", "cond_view", "cond_motion", ], "additional_guider_kwargs": { "additional_cond_keys": ["cond_view", "cond_motion"] }, }, }, }, "sv4d2_8views": { "T": 5, # number of frames per sample "V": 8, # number of views per sample "model_config": "scripts/sampling/configs/sv4d2_8views.yaml", "version_dict": { "T": 5 * 8, "options": { "discretization": 1, "cfg": 2.5, "min_cfg": 1.5, "num_views": 8, "sigma_min": 0.002, "sigma_max": 700.0, "rho": 7.0, "guider": 5, "force_uc_zero_embeddings": [ "cond_frames", "cond_frames_without_noise", "cond_view", "cond_motion", ], "additional_guider_kwargs": { "additional_cond_keys": ["cond_view", "cond_motion"] }, }, }, }, } def sample( input_path: str = "assets/sv4d_videos/camel.gif", # Can either be image file or folder with image files model_path: Optional[str] = "checkpoints/sv4d2.safetensors", output_folder: Optional[str] = "outputs", num_steps: Optional[int] = 50, img_size: int = 576, # image resolution n_frames: int = 21, # number of input and output video frames seed: int = 23, encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary. 
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", elevations_deg: Optional[List[float]] = 0.0, azimuths_deg: Optional[List[float]] = None, image_frame_ratio: Optional[float] = 0.9, verbose: Optional[bool] = False, remove_bg: bool = False, ): """ Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`. """ # Set model config assert os.path.basename(model_path) in [ "sv4d2.safetensors", "sv4d2_8views.safetensors", ] sv4d2_model = os.path.splitext(os.path.basename(model_path))[0] config = sv4d2_configs[sv4d2_model] print(sv4d2_model, config) T = config["T"] V = config["V"] model_config = config["model_config"] version_dict = config["version_dict"] F = 8 # vae factor to downsize image->latent C = 4 H, W = img_size, img_size n_views = V + 1 # number of output video views (1 input view + 8 novel views) subsampled_views = np.arange(n_views) version_dict["H"] = H version_dict["W"] = W version_dict["C"] = C version_dict["f"] = F version_dict["options"]["num_steps"] = num_steps torch.manual_seed(seed) output_folder = os.path.join(output_folder, sv4d2_model) os.makedirs(output_folder, exist_ok=True) # Read input video frames i.e. 
images at view 0 print(f"Reading {input_path}") base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // n_views processed_input_path = preprocess_video( input_path, remove_bg=remove_bg, n_frames=n_frames, W=W, H=H, output_folder=output_folder, image_frame_ratio=image_frame_ratio, base_count=base_count, ) images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device) images_t0 = torch.zeros(n_views, 3, H, W).float().to(device) # Get camera viewpoints if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): elevations_deg = [elevations_deg] * n_views assert ( len(elevations_deg) == n_views ), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}" if azimuths_deg is None: # azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360 azimuths_deg = ( np.array([0, 60, 120, 180, 240]) if sv4d2_model == "sv4d2" else np.array([0, 30, 75, 120, 165, 210, 255, 300, 330]) ) assert ( len(azimuths_deg) == n_views ), f"Please provide a list of {n_views} values for azimuths_deg! 
Given {len(azimuths_deg)}" polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg]) azimuths_rad = np.array( [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] ) # Initialize image matrix img_matrix = [[None] * n_views for _ in range(n_frames)] for i, v in enumerate(subsampled_views): img_matrix[0][i] = images_t0[v].unsqueeze(0) for t in range(n_frames): img_matrix[t][0] = images_v0[t] # Load SV4D++ model model, _ = load_model( model_config, device, version_dict["T"], num_steps, verbose, model_path, ) model.en_and_decode_n_samples_a_time = decoding_t for emb in model.conditioner.embedders: if isinstance(emb, VideoPredictionEmbedderWithEncoder): emb.en_and_decode_n_samples_a_time = encoding_t # Sampling novel-view videos v0 = 0 view_indices = np.arange(V) + 1 t0_list = ( range(0, n_frames, T-1) if sv4d2_model == "sv4d2" else range(0, n_frames - T + 1, T - 1) ) for t0 in tqdm(t0_list): if t0 + T > n_frames: t0 = n_frames - T frame_indices = t0 + np.arange(T) print(f"Sampling frames {frame_indices}") image = img_matrix[t0][v0] cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2) azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) cond_mv = False if t0 == 0 else True samples = run_img2vid( version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t, cond_mv=cond_mv, ) samples = samples.view(T, V, 3, H, W) for i, t in enumerate(frame_indices): for j, v in enumerate(view_indices): img_matrix[t][v] = samples[i, j][None] * 2 - 1 # Save output videos for v in view_indices: vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4") print(f"Saving {vid_file}") save_video( vid_file, [img_matrix[t][v] for t in range(n_frames) 
def benchmark_attn():
    """Time and profile `F.scaled_dot_product_attention` under each SDP backend.

    Runs on CUDA when available, otherwise on CPU, with random float16
    query/key/value tensors. For the default, math, flash and
    memory-efficient backends it prints the mean runtime in microseconds and a
    profiler table of 25 calls. A backend that raises `RuntimeError` while
    being timed is reported as unsupported and its profiling run is skipped.
    """
    # Let's define a helpful benchmarking function:
    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
    from torch.backends.cuda import SDPBackend, sdp_kernel
    from torch.profiler import ProfilerActivity, profile, record_function

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Let's define the hyper-parameters of our input
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    qkv_shape = (batch_size, num_heads, max_sequence_len, embed_dimension)
    query = torch.rand(*qkv_shape, device=device, dtype=dtype)
    key = torch.rand(*qkv_shape, device=device, dtype=dtype)
    value = torch.rand(*qkv_shape, device=device, dtype=dtype)

    print("q/k/v shape:", query.shape, key.shape, value.shape)

    # Helpful arguments mapper
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    def time_attention():
        """Mean runtime of one attention call under the active backend."""
        return benchmark_torch_function_in_microseconds(
            F.scaled_dot_product_attention, query, key, value
        )

    def profile_attention(label):
        """Profile 25 attention calls under the active backend and print the table."""
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function(label):
                for _ in range(25):
                    F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # Let's explore the speed of each of the 3 implementations
    print(f"The default implementation runs in {time_attention():.3f} microseconds")
    profile_attention("Default detailed stats")

    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        # The timing must happen inside the context manager; outside of it the
        # default backend would be measured a second time.
        print(f"The math implementation runs in {time_attention():.3f} microseconds")
        profile_attention("Math implmentation stats")

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(
                f"The flash attention implementation runs in {time_attention():.3f} microseconds"
            )
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")
        else:
            # Only profile a backend that just proved it can run.
            profile_attention("FlashAttention stats")

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        try:
            print(
                f"The memory efficient implementation runs in {time_attention():.3f} microseconds"
            )
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")
        else:
            profile_attention("EfficientAttention stats")


def run_model(model, x, context):
    """Call `model` on `x` with `context` and return its result."""
    return model(x, context)


def benchmark_transformer_blocks():
    """Compare a native-attention and an xformers `SpatialTransformer`.

    Builds both models with identical hyper-parameters, prints the mean
    forward time of each under `torch.autocast("cuda")`, then profiles 25
    forward passes per model and prints the profiler table and the peak CUDA
    memory. The autocast and memory-statistics calls are CUDA-specific.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    import torch.utils.benchmark as benchmark
    from torch.profiler import ProfilerActivity, profile, record_function

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    use_compile = False

    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64

    transformer_depth = 4

    n_heads = embed_dimension // d_head

    dtype = torch.float16

    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and use_compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    def profile_model(title, record_label, model, memory_note):
        """Profile 25 forward passes of `model` and report peak CUDA memory."""
        print(75 * "+")
        print(title)
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function(record_label):
                for _ in range(25):
                    model(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, memory_note)

    with torch.autocast("cuda"):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )

        profile_model(
            "NATIVE", "NativeAttention stats", model_native, "GB used by native block"
        )
        profile_model(
            "Xformers",
            "xformers stats",
            model_efficient_attn,
            "GB used by xformers block",
        )


def test01():
    """Print parameter counts and outputs of a 1x1 conv and an equally-initialised linear layer.

    Requires CUDA: both layers and the input are moved to the GPU.
    """
    # conv1x1 vs linear
    from sgm.util import count_params

    conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).cuda()
    print(count_params(linear))

    print(conv.weight.shape)

    # use same initialization
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    print(linear.weight.shape)

    x = torch.randn(11, 3, 64, 64).cuda()

    # The linear layer expects channels last, so flatten the spatial grid.
    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)

    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)
    print("done with test01.\n")
RESOURCES_ROOT = "scripts/util/detection/"


def predict_proba(X, weights, biases):
    """Return the logistic-regression probabilities for `X`, transposed.

    Computes `sigmoid(X @ weights.T + biases)` and returns its transpose.
    The sigmoid is evaluated through `exp(-|logit|)`, which cannot overflow;
    evaluating `exp(logit)` and `exp(-logit)` for every element, as
    `np.where` does with two full branches, overflows for large magnitudes.
    """
    logits = X @ weights.T + biases
    decay = np.exp(-np.abs(logits))  # always in (0, 1]
    proba = np.where(logits >= 0, 1 / (1 + decay), decay / (1 + decay))
    return proba.T


def load_model_weights(path: str):
    """Load the `weights` and `biases` arrays from the `.npz` file at `path`."""
    # The archive is read lazily; the context manager closes the file handle
    # once both arrays have been materialised.
    with np.load(path) as model_weights:
        return model_weights["weights"], model_weights["biases"]


def clip_process_images(images: torch.Tensor) -> torch.Tensor:
    """Center-crop, resize to 224 and normalise `images` for the CLIP encoder.

    The crop size is the smaller of the last two dimensions of `images`.
    """
    min_size = min(images.shape[-2:])
    preprocess = T.Compose(
        [
            T.CenterCrop(min_size),  # TODO: this might affect the watermark, check this
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
    return preprocess(images)


class DeepFloydDataFiltering(object):
    """Blur images that either of two CLIP-feature classifier heads flags.

    The heads are loaded from `p_head_v1.npz` and `w_head_v1.npz` under
    `RESOURCES_ROOT`; an image is blurred when a head's probability exceeds
    0.5. Judging by the module name, the heads presumably detect NSFW content
    and watermarks respectively — confirm against the model files.
    """

    def __init__(
        self, verbose: bool = False, device: torch.device = torch.device("cpu")
    ):
        """Load the CLIP ViT-L/14 encoder on `device` and both classifier heads."""
        super().__init__()
        self.verbose = verbose
        self._device = None  # resolved lazily from the CLIP parameters
        self.clip_model, _ = clip.load("ViT-L/14", device=device)
        self.clip_model.eval()

        self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
            os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
        )
        self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
            os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
        )
        self.w_threshold, self.p_threshold = 0.5, 0.5

    def _blur_flagged(self, images, pred, threshold, name):
        """Blur, in place, the images whose prediction exceeds `threshold`."""
        query = pred > threshold
        if query.sum() > 0:
            if self.verbose:
                print(f"Hit for {name}: {pred}")
            images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])

    @torch.inference_mode()
    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """Return `images` with every flagged image blurred.

        `images` is modified in place and also returned.
        """
        imgs = clip_process_images(images)
        if self._device is None:
            self._device = next(self.clip_model.parameters()).device
        image_features = self.clip_model.encode_image(imgs.to(self._device))
        image_features = image_features.detach().cpu().numpy().astype(np.float16)
        p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
        w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
        if self.verbose:
            print(f"p_pred = {p_pred}, w_pred = {w_pred}")
        self._blur_flagged(images, p_pred, self.p_threshold, "p_threshold")
        self._blur_flagged(images, w_pred, self.w_threshold, "w_threshold")
        return images


def load_img(path: str) -> torch.Tensor:
    """Read the image at `path` as an RGB tensor with a leading batch axis."""
    # Convert to a tensor while the file is still open, then release the handle.
    with Image.open(path) as image:
        if image.mode != "RGB":
            image = image.convert("RGB")
        return T.ToTensor()(image)[None, ...]
class CIFAR10DataDictWrapper(Dataset):
    """
    Adapts a dataset that yields ``(x, y)`` pairs to one that yields
    dictionaries with the keys ``"jpg"`` (first element) and ``"cls"``
    (second element).
    """

    def __init__(self, dset):
        super().__init__()
        # the wrapped dataset; must support indexing and ``len``
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, i):
        image, label = self.dset[i]
        return dict(jpg=image, cls=label)
pass def train_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, ) def test_dataloader(self): return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, ) def val_dataloader(self): return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, ) ================================================ FILE: sgm/data/dataset.py ================================================ from typing import Optional import torchdata.datapipes.iter import webdataset as wds from omegaconf import DictConfig from pytorch_lightning import LightningDataModule try: from sdata import create_dataset, create_dummy_dataset, create_loader except ImportError as e: print("#" * 100) print("Datasets not yet available") print("to enable, we need to add stable-datasets as a submodule") print("please use ``git submodule update --init --recursive``") print("and do ``pip install -e stable-datasets/`` from the root of this repo") print("#" * 100) exit(1) class StableDataModuleFromConfig(LightningDataModule): def __init__( self, train: DictConfig, validation: Optional[DictConfig] = None, test: Optional[DictConfig] = None, skip_val_loader: bool = False, dummy: bool = False, ): super().__init__() self.train_config = train assert ( "datapipeline" in self.train_config and "loader" in self.train_config ), "train config requires the fields `datapipeline` and `loader`" self.val_config = validation if not skip_val_loader: if self.val_config is not None: assert ( "datapipeline" in self.val_config and "loader" in self.val_config ), "validation config requires the fields `datapipeline` and `loader`" else: print( "Warning: No Validation datapipeline defined, using that one from training" ) self.val_config = train self.test_config = test if self.test_config is not None: assert ( "datapipeline" in self.test_config and 
class MNISTLoader(pl.LightningDataModule):
    """
    Lightning data module serving MNIST as ``{"jpg": image, "cls": label}``
    dictionaries, with images rescaled by ``x * 2.0 - 1.0`` after
    ``ToTensor``.

    Args:
        batch_size: batch size used by every dataloader.
        num_workers: number of ``DataLoader`` worker processes.
        prefetch_factor: prefetch factor handed to ``DataLoader``; only used
            when ``num_workers > 0``.
        shuffle: whether the dataloaders shuffle (applied to the train, test
            and validation loaders alike).
    """

    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
        super().__init__()

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
        )

        self.batch_size = batch_size
        self.num_workers = num_workers
        # ``DataLoader`` raises a ValueError when a prefetch factor is given
        # without worker processes, so there is no meaningful value to keep
        # in that case (the previous placeholder ``0`` triggered that error).
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
        self.shuffle = shuffle
        self.train_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=".data/", train=True, download=True, transform=transform
            )
        )
        self.test_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=".data/", train=False, download=True, transform=transform
            )
        )

    def prepare_data(self):
        pass

    def _make_loader(self, dataset):
        """Build a ``DataLoader`` over ``dataset`` with this module's settings."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "shuffle": self.shuffle,
            "num_workers": self.num_workers,
        }
        # Only forward the prefetch factor when worker processes exist;
        # otherwise let ``DataLoader`` fall back to its own default.
        if self.num_workers > 0 and self.prefetch_factor is not None:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

    def val_dataloader(self):
        # validation reuses the test split
        return self._make_loader(self.test_dataset)
class SamplingPipeline:
    """
    Loads one of the models registered in ``model_specs`` and exposes
    text-to-image, image-to-image and refiner sampling on top of
    ``do_sample`` / ``do_img2img``.
    """

    def __init__(
        self,
        model_id: ModelArchitecture,
        model_path="checkpoints",
        config_path="configs/inference",
        device="cuda",
        use_fp16=True,
    ) -> None:
        """
        Args:
            model_id: key into ``model_specs`` selecting config and checkpoint.
            model_path: directory containing the checkpoint file.
            config_path: directory containing the inference config file.
            device: device the model is moved to.
            use_fp16: if true, the conditioner and the inner model are cast
                to half precision.

        Raises:
            ValueError: if ``model_id`` is not in ``model_specs`` or the model
                could not be loaded.
        """
        if model_id not in model_specs:
            raise ValueError(f"Model {model_id} not supported")
        self.model_id = model_id
        self.specs = model_specs[self.model_id]
        self.config = str(pathlib.Path(config_path, self.specs.config))
        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
        self.device = device
        self.model = self._load_model(device=device, use_fp16=use_fp16)

    def _load_model(self, device="cuda", use_fp16=True):
        """Instantiate the model from ``self.config`` / ``self.ckpt`` and move it to ``device``."""
        config = OmegaConf.load(self.config)
        model = load_model_from_config(config, self.ckpt)
        if model is None:
            raise ValueError(f"Model {self.model_id} could not be loaded")
        model.to(device)
        if use_fp16:
            model.conditioner.half()
            model.model.half()
        return model

    def _force_uc_zero_embeddings(self):
        """Embedder keys whose unconditional embedding is zeroed (none for legacy models)."""
        return ["txt"] if not self.specs.is_legacy else []

    def text_to_image(
        self,
        params: SamplingParams,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        """
        Sample ``samples`` images of size ``params.height`` x ``params.width``
        for ``prompt``.

        Returns whatever ``do_sample`` returns: the samples, or a
        ``(samples, latents)`` pair when ``return_latents`` is set.
        """
        sampler = get_sampler_config(params)
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = params.width
        value_dict["target_height"] = params.height
        return do_sample(
            self.model,
            sampler,
            value_dict,
            samples,
            params.height,
            params.width,
            self.specs.channels,
            self.specs.factor,
            force_uc_zero_embeddings=self._force_uc_zero_embeddings(),
            return_latents=return_latents,
            filter=None,
        )

    def image_to_image(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        """
        Run img2img sampling starting from ``image``.

        The target size is read from ``image.shape[2]`` (height) and
        ``image.shape[3]`` (width). When ``params.img2img_strength`` is below
        1.0 the sampler's discretization is wrapped so that only part of the
        sigma schedule is used.
        """
        sampler = get_sampler_config(params)

        if params.img2img_strength < 1.0:
            sampler.discretization = Img2ImgDiscretizationWrapper(
                sampler.discretization,
                strength=params.img2img_strength,
            )
        height, width = image.shape[2], image.shape[3]
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = width
        value_dict["target_height"] = height
        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            force_uc_zero_embeddings=self._force_uc_zero_embeddings(),
            return_latents=return_latents,
            filter=None,
        )

    def refiner(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        samples: int = 1,
        return_latents: bool = False,
    ):
        """
        Refine ``image``, which is handed to ``do_img2img`` with
        ``skip_encode=True`` (i.e. it is used as the latent directly).

        The original/target sizes are derived from the spatial shape of
        ``image`` multiplied by the model's downsampling factor.
        ``negative_prompt=None`` is treated like the empty prompt, matching
        the default of ``text_to_image`` / ``image_to_image``.
        """
        sampler = get_sampler_config(params)
        # A ``None`` prompt would otherwise be forwarded verbatim as the
        # unconditional text; normalise it to the empty string.
        if negative_prompt is None:
            negative_prompt = ""
        # latent -> pixel scale; every entry of ``model_specs`` uses 8 here
        factor = self.specs.factor
        value_dict = {
            "orig_width": image.shape[3] * factor,
            "orig_height": image.shape[2] * factor,
            "target_width": image.shape[3] * factor,
            "target_height": image.shape[2] * factor,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "crop_coords_top": 0,
            "crop_coords_left": 0,
            "aesthetic_score": 6.0,
            "negative_aesthetic_score": 2.5,
        }

        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            skip_encode=True,
            return_latents=return_latents,
            filter=None,
        )
guider_config = get_guider_config(params) sampler = None if params.sampler == Sampler.EULER_EDM: return EulerEDMSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=params.s_churn, s_tmin=params.s_tmin, s_tmax=params.s_tmax, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.HEUN_EDM: return HeunEDMSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=params.s_churn, s_tmin=params.s_tmin, s_tmax=params.s_tmax, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.EULER_ANCESTRAL: return EulerAncestralSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, eta=params.eta, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.DPMPP2S_ANCESTRAL: return DPMPP2SAncestralSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, eta=params.eta, s_noise=params.s_noise, verbose=True, ) if params.sampler == Sampler.DPMPP2M: return DPMPP2MSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, verbose=True, ) if params.sampler == Sampler.LINEAR_MULTISTEP: return LinearMultistepSampler( num_steps=params.steps, discretization_config=discretization_config, guider_config=guider_config, order=params.order, verbose=True, ) raise ValueError(f"unknown sampler {params.sampler}!") ================================================ FILE: sgm/inference/helpers.py ================================================ import math import os from typing import List, Optional, Union import numpy as np import torch from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig from PIL import Image from torch import autocast from sgm.util import append_dims class WatermarkEmbedder: def __init__(self, watermark): self.watermark = watermark 
class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0, "strength must lie in [0.0, 1.0]"

    def __call__(self, *args, **kwargs):
        """
        Call the wrapped discretizer and keep only the trailing part of its
        sigma schedule.

        The number of sigmas kept is ``max(int(strength * len(sigmas)), 1)``,
        so at least one sigma is always returned. The order of the returned
        sigmas matches the order produced by the wrapped discretizer.
        """
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Compute the cut once, from the full schedule, so that the value
        # logged below is the index that was actually applied.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """
    Assemble the conditional and unconditional input batches for the
    conditioner (hardcoded demo setups; might undergo some changes in the
    future).

    For every entry of ``keys`` the matching values are read from
    ``value_dict`` and tiled according to ``N``: text prompts become nested
    lists of shape ``N``, numeric conditionings become tensors on ``device``
    repeated ``N`` times. Unknown keys are copied from ``value_dict``
    unchanged. Tensor entries without a dedicated unconditional counterpart
    are cloned into the unconditional batch.

    Returns:
        A ``(batch, batch_uc)`` pair of dictionaries.
    """

    def _tile_text(text):
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    def _tile_values(*names):
        values = [value_dict[name] for name in names]
        return torch.tensor(values).to(device).repeat(*N, 1)

    # conditionings that are a (first, second) pair of scalars from value_dict
    pair_fields = {
        "original_size_as_tuple": ("orig_height", "orig_width"),
        "crop_coords_top_left": ("crop_coords_top", "crop_coords_left"),
        "target_size_as_tuple": ("target_height", "target_width"),
    }

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = _tile_text(value_dict["prompt"])
            batch_uc["txt"] = _tile_text(value_dict["negative_prompt"])
        elif key in pair_fields:
            batch[key] = _tile_values(*pair_fields[key])
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = _tile_values("aesthetic_score")
            batch_uc["aesthetic_score"] = _tile_values("negative_aesthetic_score")
        else:
            batch[key] = value_dict[key]

    # everything tensor-valued that has no explicit unconditional version
    # is shared (as a copy) with the unconditional batch
    for key, value in batch.items():
        if isinstance(value, torch.Tensor) and key not in batch_uc:
            batch_uc[key] = value.clone()
    return batch, batch_uc
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    device="cuda",
):
    """
    Noise ``img`` (or its latent) to the first sigma of the sampler's
    schedule and denoise it again with ``sampler``.

    Args:
        img: input image; used directly as the latent when ``skip_encode``
            is true, otherwise encoded with ``model.encode_first_stage``.
        model: model providing ``conditioner``, ``denoiser``, ``model``,
            ``ema_scope``, ``encode_first_stage`` and ``decode_first_stage``.
        sampler: sampler providing ``discretization`` and ``num_steps``.
        value_dict: values consumed by ``get_batch``.
        num_samples: number of samples; conditionings are cut to this length.
        force_uc_zero_embeddings: forwarded to the conditioner; ``None``
            means an empty list.
        additional_kwargs: extra entries written into both the conditional
            and unconditional conditioning; ``None`` means none.
        offset_noise_level: if positive, per-sample offset noise scaled by
            this value is added to the noise.
        return_latents: also return the sampled latents.
        skip_encode: treat ``img`` as a latent.
        filter: optional callable applied to the decoded samples.
        device: device for the batch and conditionings.

    Returns:
        The decoded samples clamped to [0, 1], or ``(samples, latents)``
        when ``return_latents`` is set.
    """
    # ``None`` sentinels instead of shared mutable defaults (same convention
    # as ``do_sample``).
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if additional_kwargs is None:
        additional_kwargs = {}

    with torch.no_grad(), autocast(device), model.ema_scope():
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            [num_samples],
            device=device,  # keep the batch on the requested device
        )
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
        )

        for k in c:
            c[k], uc[k] = (
                c[k][:num_samples].to(device),
                uc[k][:num_samples].to(device),
            )

        for k in additional_kwargs:
            c[k] = uc[k] = additional_kwargs[k]

        z = img if skip_encode else model.encode_first_stage(img)
        noise = torch.randn_like(z)

        sigmas = sampler.discretization(sampler.num_steps)
        sigma = sigmas[0].to(z.device)

        if offset_noise_level > 0.0:
            noise = noise + offset_noise_level * append_dims(
                torch.randn(z.shape[0], device=z.device), z.ndim
            )
        noised_z = z + noise * append_dims(sigma, z.ndim)
        # Note: hardcoded to DDPM-like scaling. need to generalize later.
        noised_z = noised_z / torch.sqrt(1.0 + sigma**2.0)

        def denoiser(x, sigma, c):
            return model.denoiser(model.model, x, sigma, c)

        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
        samples_x = model.decode_first_stage(samples_z)
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

        if filter is not None:
            samples = filter(samples)

        if return_latents:
            return samples, samples_z
        return samples
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """
    Variant of the cyclic warm-up scheduler whose post-warm-up phase decays
    linearly instead of following a cosine.
    """

    def schedule(self, n, **kwargs):
        """
        Return the lr multiplier for global step ``n``.

        Within a cycle the multiplier rises linearly from ``f_start`` to
        ``f_max`` during the warm-up steps and afterwards follows
        ``f_min + (f_max - f_min) * (cycle_length - n) / cycle_length`` with
        ``n`` counted from the start of the cycle. The value is also stored
        in ``self.last_f``.
        """
        cycle = self.find_in_interval(n)
        # from here on ``n`` is the step index relative to the current cycle
        n = n - self.cum_cycles[cycle]

        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(
                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                f"current cycle {cycle}"
            )

        warm_up_steps = self.lr_warm_up_steps[cycle]
        if n < warm_up_steps:
            f_start = self.f_start[cycle]
            f = (self.f_max[cycle] - f_start) / warm_up_steps * n + f_start
        else:
            f_min = self.f_min[cycle]
            cycle_length = self.cycle_lengths[cycle]
            f = f_min + (self.f_max[cycle] - f_min) * (cycle_length - n) / (
                cycle_length
            )

        self.last_f = f
        return f
class AbstractAutoencoder(pl.LightningModule):
    """
    Common base for every autoencoder in this package (plain image
    autoencoders, autoencoders trained with a discriminator, unCLIP models,
    ...). It only provides shared plumbing - EMA bookkeeping, checkpoint
    loading and optimizer instantiation - while encoding, decoding and any
    training specifics are left to subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
    ):
        super().__init__()

        self.input_key = input_key
        # EMA tracking is enabled exactly when a decay value is supplied
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        running_torch_2 = version.parse(torch.__version__) >= version.parse("2.0.0")
        if running_torch_2:
            self.automatic_optimization = False

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        """
        Apply a checkpoint to this module. ``ckpt`` may be ``None`` (no-op),
        a path (wrapped into a ``CheckpointEngine`` config) or a config
        understood by ``instantiate_from_config``.
        """
        if ckpt is None:
            return
        if isinstance(ckpt, str):
            ckpt = {
                "target": "sgm.modules.checkpoint.CheckpointEngine",
                "params": {"ckpt_path": ckpt},
            }
        checkpoint_engine = instantiate_from_config(ckpt)
        checkpoint_engine(self)

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # update the EMA weights after every training batch
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        """
        Context manager that swaps in the EMA weights for the duration of
        the ``with`` body and restores the stored weights afterwards.
        Does nothing when EMA is disabled.
        """
        if not self.use_ema:
            yield None
            return

        self.model_ema.store(self.parameters())
        self.model_ema.copy_to(self)
        if context is not None:
            logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            self.model_ema.restore(self.parameters())
            if context is not None:
                logpy.info(f"{context}: Restored training weights")

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        """Build the optimizer class named by ``cfg["target"]`` for ``params``."""
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        optimizer_cls = get_obj_from_str(cfg["target"])
        extra_args = cfg.get("params", {})
        return optimizer_cls(params, lr=lr, **extra_args)

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()
restore them explicitly as special cases for legacy reasons). Regularizations such as KL or VQ are moved to the regularizer class. """ def __init__( self, *args, encoder_config: Dict, decoder_config: Dict, loss_config: Dict, regularizer_config: Dict, optimizer_config: Union[Dict, None] = None, lr_g_factor: float = 1.0, trainable_ae_params: Optional[List[List[str]]] = None, ae_optimizer_args: Optional[List[dict]] = None, trainable_disc_params: Optional[List[List[str]]] = None, disc_optimizer_args: Optional[List[dict]] = None, disc_start_iter: int = 0, diff_boost_factor: float = 3.0, ckpt_engine: Union[None, str, dict] = None, ckpt_path: Optional[str] = None, additional_decode_keys: Optional[List[str]] = None, **kwargs, ): super().__init__(*args, **kwargs) self.automatic_optimization = False # pytorch lightning self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) self.loss: torch.nn.Module = instantiate_from_config(loss_config) self.regularization: AbstractRegularizer = instantiate_from_config( regularizer_config ) self.optimizer_config = default( optimizer_config, {"target": "torch.optim.Adam"} ) self.diff_boost_factor = diff_boost_factor self.disc_start_iter = disc_start_iter self.lr_g_factor = lr_g_factor self.trainable_ae_params = trainable_ae_params if self.trainable_ae_params is not None: self.ae_optimizer_args = default( ae_optimizer_args, [{} for _ in range(len(self.trainable_ae_params))], ) assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) else: self.ae_optimizer_args = [{}] # makes type consitent self.trainable_disc_params = trainable_disc_params if self.trainable_disc_params is not None: self.disc_optimizer_args = default( disc_optimizer_args, [{} for _ in range(len(self.trainable_disc_params))], ) assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) else: self.disc_optimizer_args = [{}] # makes type consitent if ckpt_path is not 
None: assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") self.apply_ckpt(default(ckpt_path, ckpt_engine)) self.additional_decode_keys = set(default(additional_decode_keys, [])) def get_input(self, batch: Dict) -> torch.Tensor: # assuming unified data format, dataloader returns a dict. # image tensors should be scaled to -1 ... 1 and in channels-first # format (e.g., bchw instead if bhwc) return batch[self.input_key] def get_autoencoder_params(self) -> list: params = [] if hasattr(self.loss, "get_trainable_autoencoder_parameters"): params += list(self.loss.get_trainable_autoencoder_parameters()) if hasattr(self.regularization, "get_trainable_parameters"): params += list(self.regularization.get_trainable_parameters()) params = params + list(self.encoder.parameters()) params = params + list(self.decoder.parameters()) return params def get_discriminator_params(self) -> list: if hasattr(self.loss, "get_trainable_parameters"): params = list(self.loss.get_trainable_parameters()) # e.g., discriminator else: params = [] return params def get_last_layer(self): return self.decoder.get_last_layer() def encode( self, x: torch.Tensor, return_reg_log: bool = False, unregularized: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: z = self.encoder(x) if unregularized: return z, dict() z, reg_log = self.regularization(z) if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: x = self.decoder(z, **kwargs) return x def forward( self, x: torch.Tensor, **additional_decode_kwargs ) -> Tuple[torch.Tensor, torch.Tensor, dict]: z, reg_log = self.encode(x, return_reg_log=True) dec = self.decode(z, **additional_decode_kwargs) return z, dec, reg_log def inner_training_step( self, batch: dict, batch_idx: int, optimizer_idx: int = 0 ) -> torch.Tensor: x = self.get_input(batch) additional_decode_kwargs = { key: batch[key] for key in 
def training_step(self, batch: dict, batch_idx: int):
    """Run one manually-optimized step with a single optimizer.

    The optimizer is picked round-robin from ``batch_idx``; while
    ``global_step`` is below ``disc_start_iter`` the first optimizer is
    always used.
    """
    available = self.optimizers()
    # A single (non-list) return value is the non-adversarial case.
    optimizer_list = available if isinstance(available, list) else [available]

    scheduled_idx = batch_idx % len(optimizer_list)
    active_idx = 0 if self.global_step < self.disc_start_iter else scheduled_idx

    active_opt = optimizer_list[active_idx]
    active_opt.zero_grad()
    with active_opt.toggle_model():
        step_loss = self.inner_training_step(
            batch, batch_idx, optimizer_idx=active_idx
        )
        self.manual_backward(step_loss)
    active_opt.step()
def get_param_groups(
    self, parameter_names: List[List[str]], optimizer_args: List[dict]
) -> Tuple[List[Dict[str, Any]], int]:
    """Build optimizer parameter groups from lists of name patterns.

    Args:
        parameter_names: One list of regular-expression patterns per group.
            Patterns are applied with ``match`` semantics, i.e. anchored at
            the start of each parameter name.
        optimizer_args: One dict of extra group options per entry of
            ``parameter_names``; merged into the group next to ``"params"``.

    Returns:
        ``(groups, num_params)`` where ``num_params`` sums ``numel()`` over
        every match. A parameter matched by several patterns is added and
        counted once per match.
    """
    # Materialize once so the module tree is not re-walked for every pattern.
    named_params = list(self.named_parameters())
    groups = []
    num_params = 0
    for names, args in zip(parameter_names, optimizer_args):
        params = []
        for pattern_ in names:
            pattern = re.compile(pattern_)
            pattern_params = [
                param for p_name, param in named_params if pattern.match(p_name)
            ]
            if not pattern_params:
                logpy.warning(f"Did not find parameters for pattern {pattern_}")
            num_params += sum(param.numel() for param in pattern_params)
            params.extend(pattern_params)
        groups.append({"params": params, **args})
    return groups, num_params
@torch.no_grad()
def log_images(
    self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
    """Collect inputs, reconstructions and error maps for logging.

    Reconstructions are produced with the current weights, again inside
    ``ema_scope()``, and optionally once more with ``additional_log_kwargs``
    merged into the decode kwargs. Extra images are added when the loss
    object exposes ``log_images``.
    """

    def error_maps(recon):
        # Halved absolute error against the input, kept within [0, 1].
        err = (0.5 * torch.abs(torch.clamp(recon, -1.0, 1.0) - x)).clamp(0, 1.0)
        plain = 2.0 * err - 1.0
        # The boosted map shows the location of small errors by raising
        # their brightness.
        boosted = 2.0 * torch.clamp(self.diff_boost_factor * err, 0.0, 1.0) - 1
        return plain, boosted

    x = self.get_input(batch)
    decode_kwargs = {
        key: batch[key] for key in self.additional_decode_keys.intersection(batch)
    }

    images = {}
    _, xrec, _ = self(x, **decode_kwargs)
    images["inputs"] = x
    images["reconstructions"] = xrec
    images["diff"], images["diff_boost"] = error_maps(xrec)
    if hasattr(self.loss, "log_images"):
        images.update(self.loss.log_images(x, xrec))

    with self.ema_scope():
        _, xrec_ema, _ = self(x, **decode_kwargs)
        images["reconstructions_ema"] = xrec_ema
        images["diff_ema"], images["diff_boost_ema"] = error_maps(xrec_ema)

    if additional_log_kwargs:
        decode_kwargs.update(additional_log_kwargs)
        _, xrec_add, _ = self(x, **decode_kwargs)
        suffix = "-".join(
            f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs
        )
        images["reconstructions-" + suffix] = xrec_add
    return images
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
    """Legacy autoencoder configured with a ``VectorQuantizer`` regularizer."""

    def __init__(
        self,
        embed_dim: int,
        n_embed: int,
        sane_index_shape: bool = False,
        **kwargs,
    ):
        """Configure the vector-quantization regularizer and build the engine.

        Args:
            embed_dim: Used as the quantizer's ``e_dim`` and forwarded to
                the parent constructor.
            n_embed: Used as the quantizer's ``n_e``.
            sane_index_shape: Passed through to the quantizer.
            **kwargs: Forwarded to the parent constructor. The deprecated
                key ``lossconfig`` is renamed to ``loss_config``.
        """
        if "lossconfig" in kwargs:
            logpy.warning("Parameter `lossconfig` is deprecated, use `loss_config`.")
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            # The parent constructor requires `embed_dim`. This signature
            # captures it, so it has to be passed on explicitly; otherwise
            # construction fails with a missing-argument TypeError.
            embed_dim=embed_dim,
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers.quantize"
                    ".VectorQuantizer"
                ),
                "params": {
                    "n_e": n_embed,
                    "e_dim": embed_dim,
                    "sane_index_shape": sane_index_shape,
                },
            },
            **kwargs,
        )
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
    """Legacy autoencoder whose ``DiagonalGaussianRegularizer`` is
    configured with ``sample=False``.
    """

    def __init__(self, **kwargs):
        # Accept the old spelling of the loss config key.
        legacy_key = "lossconfig"
        if legacy_key in kwargs:
            kwargs["loss_config"] = kwargs.pop(legacy_key)

        regularizer_config = {
            "target": "sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer",
            "params": {"sample": False},
        }
        super().__init__(regularizer_config=regularizer_config, **kwargs)
class DiffusionEngine(pl.LightningModule):
    """Lightning module that trains and samples a denoising network on the
    outputs of a frozen first-stage model.

    It owns the wrapped network, denoiser, conditioner, optional sampler and
    loss function, and an optional EMA copy of the network weights.
    """

    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        """Instantiate all sub-modules from their configs.

        Sampler and loss function are optional here but both are required
        once training starts (see ``on_train_start``). When ``ckpt_path`` is
        given, weights are restored via ``init_from_ckpt`` after all
        sub-modules exist.
        """
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def init_from_ckpt(
        self,
        path: str,
    ) -> None:
        """Load weights from ``path`` non-strictly and print key mismatches.

        Raises:
            NotImplementedError: If ``path`` ends with neither ``ckpt`` nor
                ``safetensors``.
        """
        if path.endswith("ckpt"):
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints that come from a trusted source.
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError(
                f"Unsupported checkpoint format for '{path}'; "
                "expected a path ending in 'ckpt' or 'safetensors'."
            )

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def _init_first_stage(self, config):
        """Instantiate the first-stage model in eval mode and freeze it."""
        model = instantiate_from_config(config).eval()
        # Replace `train` so later calls cannot switch the model out of eval mode.
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def get_input(self, batch):
        """Return ``batch[self.input_key]``."""
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in bchw format
        return batch[self.input_key]

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo ``scale_factor`` and decode ``z`` with the first-stage model.

        Decoding runs in chunks of ``en_and_decode_n_samples_a_time`` items
        (the whole batch when unset). A ``VideoDecoder`` additionally
        receives the chunk length as ``timesteps``.
        """
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                chunk = z[n * n_samples : (n + 1) * n_samples]
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(chunk)}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(chunk, **kwargs)
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Encode ``x`` with the first-stage model and apply ``scale_factor``.

        Encoding runs in chunks of ``en_and_decode_n_samples_a_time`` items
        (the whole batch when unset).
        """
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(
                    x[n * n_samples : (n + 1) * n_samples]
                )
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        return z

    def forward(self, x, batch):
        """Return the mean loss and a ``{"loss": mean}`` dict for logging."""
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

    def shared_step(self, batch: Dict) -> Any:
        """Encode the batch input and compute the loss.

        Note that ``batch`` is modified in place: ``"global_step"`` is added.
        """
        x = self.get_input(batch)
        x = self.encode_first_stage(x)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Compute the loss and log it together with step and learning rate."""
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(
            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log(
                "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
            )

        return loss

    def on_train_start(self, *args, **kwargs):
        """Fail early when training is started without sampler or loss."""
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training.")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after each batch when EMA is enabled."""
        if self.use_ema:
            self.model_ema(self.model)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the EMA weights into ``self.model``.

        Does nothing when EMA is disabled. The stored weights are restored
        on exit even if the body raises.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        """Build the optimizer named by ``cfg["target"]`` with ``cfg["params"]``."""
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        """Create the optimizer (and a ``LambdaLR`` scheduler when configured).

        Parameters of the network and of every conditioner embedder flagged
        ``is_trainable`` are optimized.
        """
        lr = self.learning_rate
        params = list(self.model.parameters())
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                params = params + list(embedder.parameters())
        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        **kwargs,
    ):
        """Draw ``batch_size`` samples with the configured sampler.

        Args:
            cond: Conditioning passed to the sampler.
            uc: Unconditional conditioning passed to the sampler.
            batch_size: Number of samples.
            shape: Per-sample shape of the initial noise; required.
            **kwargs: Forwarded to every denoiser call.

        Raises:
            ValueError: If ``shape`` is not provided.
        """
        if shape is None:
            raise ValueError("`shape` must be provided to draw the initial noise.")
        randn = torch.randn(batch_size, *shape).to(self.device)

        def denoiser(input, sigma, c):
            return self.denoiser(self.model, input, sigma, c, **kwargs)

        samples = self.sampler(denoiser, randn, cond, uc=uc)
        return samples

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[2:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if (
                (self.log_keys is None) or (embedder.input_key in self.log_keys)
            ) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            "x".join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (list, ListConfig)):
                    if isinstance(x[0], str):
                        # strings
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_images(
        self,
        batch: Dict,
        N: int = 8,
        sample: bool = True,
        ucg_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        """Collect inputs, reconstructions, conditionings and samples.

        Args:
            batch: Data batch; at most the first ``N`` items are used.
            N: Upper bound on the number of logged items.
            sample: Whether to also draw samples (under EMA weights).
            ucg_keys: Conditioner input keys to zero for the unconditional
                branch; defaults to all conditioner input keys.

        Raises:
            AssertionError: If ``ucg_keys`` contains a key that no
                conditioner embedder uses.
        """
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0
            else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        log["inputs"] = x
        z = self.encode_first_stage(x)
        log["reconstructions"] = self.decode_first_stage(z)
        log.update(self.log_conditionings(batch, N))

        # Trim tensor conditionings to the logged items and move them to
        # the module's device.
        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples"] = samples
        return log
def exists(val):
    """Return True unless ``val`` is None."""
    return not (val is None)


def uniq(arr):
    """Return a keys view of the distinct elements of ``arr`` in first-seen order."""
    return dict.fromkeys(arr, True).keys()


def default(val, d):
    """Return ``val`` unless it is None, otherwise the fallback ``d``.

    A fallback that is a plain Python function is called and its result
    returned; any other fallback is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val


def max_neg_value(t):
    """Return the most negative finite value of ``t``'s floating dtype."""
    limits = torch.finfo(t.dtype)
    return -limits.max


def init_(tensor):
    """Fill ``tensor`` in place from U(-b, b) with b = 1/sqrt(last dim); return it."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
class SelfAttention(nn.Module):
    """Multi-head self-attention with three interchangeable implementations.

    ``attn_mode`` selects the kernel: ``"xformers"`` uses
    ``xformers.ops.memory_efficient_attention``, ``"torch"`` uses
    ``scaled_dot_product_attention`` and ``"math"`` is an explicit
    softmax(QK^T)V. Only the ``"math"`` path applies ``self.scale`` and
    ``attn_drop``; the other two call their kernels without them.
    """

    ATTENTION_MODES = ("xformers", "torch", "math")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        attn_mode: str = "xformers",
    ):
        """
        Args:
            dim: Size of the last input dimension; also the output size.
            num_heads: Number of attention heads (``dim // num_heads`` each).
            qkv_bias: Whether the fused QKV projection has a bias.
            qk_scale: Logit scale; ``head_dim ** -0.5`` when falsy.
            attn_drop: Dropout on attention weights (``"math"`` mode only).
            proj_drop: Dropout after the output projection.
            attn_mode: One of ``ATTENTION_MODES``.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape (B, L, C).

        Raises:
            NotImplementedError: If ``self.attn_mode`` is not a known mode.
        """
        B, L, C = x.shape

        qkv = self.qkv(x)
        if self.attn_mode == "torch":
            qkv = rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            ).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "B H L D -> B L (H D)")
        elif self.attn_mode == "xformers":
            qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
        elif self.attn_mode == "math":
            qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # `NotImplemented` is a comparison sentinel, not an exception;
            # raising it produces a confusing TypeError.
            raise NotImplementedError(f"Unknown attention mode '{self.attn_mode}'")

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def forward(
    self,
    x,
    context=None,
    mask=None,
    additional_tokens=None,
    n_times_crossframe_attn_in_self=0,
):
    """Attend from ``x`` to ``context`` (or to ``x`` itself when it is None).

    Args:
        x: Query sequence, batch first.
        context: Key/value sequence; defaults to ``x``.
        mask: Passed unchanged as ``attn_mask`` to
            ``scaled_dot_product_attention``.
        additional_tokens: Optional tokens prepended to ``x`` along dim 1;
            their outputs are stripped again before the output projection.
        n_times_crossframe_attn_in_self: When non-zero ``n``, every block of
            ``n`` consecutive batch items shares the keys/values of the
            block's first item. The batch size must be divisible by ``n``.
    """
    h = self.heads

    if additional_tokens is not None:
        # get the number of masked tokens at the beginning of the output sequence
        n_tokens_to_mask = additional_tokens.shape[1]
        # add additional token
        x = torch.cat([additional_tokens, x], dim=1)

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)

    if n_times_crossframe_attn_in_self:
        # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
        assert x.shape[0] % n_times_crossframe_attn_in_self == 0
        # Keep every n-th key/value and repeat each one n times so the
        # batch dimension is restored to its original size. Repeating
        # B // n times instead only matches the query batch when B == n * n.
        k = repeat(
            k[::n_times_crossframe_attn_in_self],
            "b ... -> (b n) ...",
            n=n_times_crossframe_attn_in_self,
        )
        v = repeat(
            v[::n_times_crossframe_attn_in_self],
            "b ... -> (b n) ...",
            n=n_times_crossframe_attn_in_self,
        )

    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

    with sdp_kernel(**BACKEND_MAP[self.backend]):
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask
        )  # scale is dim_head ** -0.5 per default

    del q, k, v
    out = rearrange(out, "b h n d -> b n (h d)", h=h)

    if additional_tokens is not None:
        # remove additional token
        out = out[:, n_tokens_to_mask:]
    return self.to_out(out)
def forward(
    self,
    x,
    context=None,
    mask=None,
    additional_tokens=None,
    n_times_crossframe_attn_in_self=0,
):
    """Attend from ``x`` to ``context`` using xformers' memory-efficient kernel.

    ``additional_tokens`` are prepended to ``x`` and their outputs removed
    again at the end. A non-None ``mask`` is not supported and raises
    ``NotImplementedError`` (after the attention has been computed).
    """
    has_extra_tokens = additional_tokens is not None
    if has_extra_tokens:
        # number of leading positions to strip from the output later
        n_tokens_to_mask = additional_tokens.shape[1]
        x = torch.cat([additional_tokens, x], dim=1)

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)

    if n_times_crossframe_attn_in_self:
        # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
        assert x.shape[0] % n_times_crossframe_attn_in_self == 0
        stride = n_times_crossframe_attn_in_self
        k = repeat(k[::stride], "b ... -> (b n) ...", n=stride)
        v = repeat(v[::stride], "b ... -> (b n) ...", n=stride)

    b, _, _ = q.shape

    def split_heads(t):
        # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
        return (
            t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous()
        )

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # actually compute the attention, what we cannot get enough of
    if version.parse(xformers.__version__) >= version.parse("0.0.21"):
        # NOTE: workaround for
        # https://github.com/facebookresearch/xformers/issues/845
        max_bs = 32768
        pieces = []
        for start in range(0, q.shape[0], max_bs):
            window = slice(start, start + max_bs)
            pieces.append(
                xformers.ops.memory_efficient_attention(
                    q[window],
                    k[window],
                    v[window],
                    attn_bias=None,
                    op=self.attention_op,
                )
            )
        out = torch.cat(pieces, 0)
    else:
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

    # TODO: Use this directly in the attention operation, as a bias
    if exists(mask):
        raise NotImplementedError

    # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
    out = (
        out.unsqueeze(0)
        .reshape(b, self.heads, out.shape[1], self.dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, out.shape[1], self.heads * self.dim_head)
    )
    if has_extra_tokens:
        # remove additional token
        out = out[:, n_tokens_to_mask:]
    return self.to_out(out)
def forward(
    self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    """Apply the block, optionally under activation checkpointing.

    Args:
        x: input passed to ``self._forward``.
        context: optional conditioning passed to ``self._forward``.
        additional_tokens: optional extra tokens passed to ``self._forward``.
        n_times_crossframe_attn_in_self: cross-frame attention factor passed
            to ``self._forward``.

    Returns:
        Whatever ``self._forward`` returns for these arguments.
    """
    if self.checkpoint:
        # All four arguments are forwarded positionally, in the order of
        # ``_forward``'s signature. The previous call passed only
        # ``x`` and ``context``, so ``additional_tokens`` and
        # ``n_times_crossframe_attn_in_self`` were silently dropped whenever
        # checkpointing was enabled.
        # NOTE(review): assumes ``checkpoint`` is
        # ``torch.utils.checkpoint.checkpoint`` (import not visible here),
        # which forwards extra positional arguments to the wrapped function.
        return checkpoint(
            self._forward,
            x,
            context,
            additional_tokens,
            n_times_crossframe_attn_in_self,
        )
    return self._forward(
        x,
        context=context,
        additional_tokens=additional_tokens,
        n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self,
    )
class BasicTransformerSingleLayerBlock(nn.Module):
    """Transformer block with a single attention layer and a feed-forward.

    Both sub-layers are applied to a LayerNorm-ed input and added back to the
    un-normalised input (residual connections).
    """

    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        # on the A100s not quite as fast as the above version
        # (todo might depend on head_dim, check, falls back to semi-optimized
        # kernels for dim!=[16,32,64,128])
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        attn_mode="softmax",
    ):
        """Build the attention layer, the feed-forward and both norms.

        Args:
            dim: feature size, used as ``query_dim`` and for the LayerNorms.
            n_heads: number of attention heads.
            d_head: size of each attention head.
            dropout: dropout rate handed to attention and feed-forward.
            context_dim: feature size of the context handed to the attention.
            gated_ff: forwarded to ``FeedForward`` as ``glu``.
            checkpoint: whether ``forward`` runs under activation
                checkpointing.
            attn_mode: key into ``ATTENTION_MODES``.
        """
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Apply the block, honouring the ``checkpoint`` constructor flag."""
        # The flag used to be stored but ignored: every call was checkpointed
        # even when the block was built with ``checkpoint=False``.
        if self.checkpoint:
            return checkpoint(self._forward, x, context)
        return self._forward(x, context)

    def _forward(self, x, context=None):
        """Attention and feed-forward, each with a residual connection."""
        x = self.attn1(self.norm1(x), context=context) + x
        x = self.ff(self.norm2(x)) + x
        return x
def __init__(
    self,
    in_channels,
    n_heads,
    d_head,
    depth=1,
    dropout=0.0,
    context_dim=None,
    disable_self_attn=False,
    use_linear=False,
    attn_type="softmax",
    use_checkpoint=True,
    # sdp_backend=SDPBackend.FLASH_ATTENTION
    sdp_backend=None,
):
    """Build input/output projections and ``depth`` transformer blocks.

    Args:
        in_channels: number of channels of the incoming feature map.
        n_heads: attention heads per block.
        d_head: size of each head; the inner width is ``n_heads * d_head``.
        depth: number of ``BasicTransformerBlock`` layers.
        dropout: dropout rate forwarded to every block.
        context_dim: ``None``, a single value, or a list with one entry per
            block. A list whose length differs from ``depth`` is replaced by
            ``depth`` copies of its first entry (all entries must be equal).
        disable_self_attn: forwarded to every block.
        use_linear: use ``nn.Linear`` instead of 1x1 ``nn.Conv2d`` for the
            input and output projections.
        attn_type: forwarded to every block as ``attn_mode``.
        use_checkpoint: forwarded to every block as ``checkpoint``.
        sdp_backend: forwarded to every block.

    Raises:
        AssertionError: if ``context_dim`` has to be broadcast to ``depth``
            but its entries are not all identical.
    """
    super().__init__()
    logpy.debug(
        f"constructing {self.__class__.__name__} of depth {depth} w/ "
        f"{in_channels} channels and {n_heads} heads."
    )

    # Normalise `context_dim` to a list with exactly one entry per block.
    if exists(context_dim) and not isinstance(context_dim, list):
        context_dim = [context_dim]
    if exists(context_dim) and isinstance(context_dim, list):
        if depth != len(context_dim):
            # NOTE(review): `logpy` is assumed to be a `logging.Logger`
            # (defined outside this block); `Logger.warn` is a deprecated
            # alias of `Logger.warning`.
            logpy.warning(
                f"{self.__class__.__name__}: Found context dims "
                f"{context_dim} of depth {len(context_dim)}, which does not "
                f"match the specified 'depth' of {depth}. Setting context_dim "
                f"to {depth * [context_dim[0]]} now."
            )
            # depth does not match context dims.
            assert all(
                dim == context_dim[0] for dim in context_dim
            ), "need homogenous context_dim to match depth automatically"
            context_dim = depth * [context_dim[0]]
    elif context_dim is None:
        context_dim = [None] * depth

    self.in_channels = in_channels
    inner_dim = n_heads * d_head
    self.norm = Normalize(in_channels)
    if not use_linear:
        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )
    else:
        self.proj_in = nn.Linear(in_channels, inner_dim)

    self.transformer_blocks = nn.ModuleList(
        [
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                attn_mode=attn_type,
                checkpoint=use_checkpoint,
                sdp_backend=sdp_backend,
            )
            for d in range(depth)
        ]
    )

    # The output projection is wrapped in `zero_module` (defined elsewhere).
    if not use_linear:
        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )
    else:
        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
    self.use_linear = use_linear
class SimpleTransformer(nn.Module):
    """A plain stack of ``BasicTransformerBlock`` layers applied in sequence."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        context_dim: Optional[int] = None,
        dropout: float = 0.0,
        checkpoint: bool = True,
    ):
        """Create ``depth`` identically configured transformer blocks.

        Every block is requested with ``attn_mode="softmax-xformers"``.
        """
        super().__init__()
        blocks = [
            BasicTransformerBlock(
                dim,
                heads,
                dim_head,
                dropout=dropout,
                context_dim=context_dim,
                attn_mode="softmax-xformers",
                checkpoint=checkpoint,
            )
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Feed ``x`` through every layer, passing the same ``context`` each time."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, context)
        return hidden
colormaps from matplotlib import pyplot as plt from ....util import default, instantiate_from_config from ..lpips.loss.lpips import LPIPS from ..lpips.model.model import weights_init from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss class GeneralLPIPSWithDiscriminator(nn.Module): def __init__( self, disc_start: int, logvar_init: float = 0.0, disc_num_layers: int = 3, disc_in_channels: int = 3, disc_factor: float = 1.0, disc_weight: float = 1.0, perceptual_weight: float = 1.0, disc_loss: str = "hinge", scale_input_to_tgt_size: bool = False, dims: int = 2, learn_logvar: bool = False, regularization_weights: Union[None, Dict[str, float]] = None, additional_log_keys: Optional[List[str]] = None, discriminator_config: Optional[Dict] = None, ): super().__init__() self.dims = dims if self.dims > 2: print( f"running with dims={dims}. This means that for perceptual loss " f"calculation, the LPIPS loss will be applied to each frame " f"independently." ) self.scale_input_to_tgt_size = scale_input_to_tgt_size assert disc_loss in ["hinge", "vanilla"] self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight # output log variance self.logvar = nn.Parameter( torch.full((), logvar_init), requires_grad=learn_logvar ) self.learn_logvar = learn_logvar discriminator_config = default( discriminator_config, { "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", "params": { "input_nc": disc_in_channels, "n_layers": disc_num_layers, "use_actnorm": False, }, }, ) self.discriminator = instantiate_from_config(discriminator_config).apply( weights_init ) self.discriminator_iter_start = disc_start self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.regularization_weights = default(regularization_weights, {}) self.forward_keys = [ "optimizer_idx", "global_step", "last_layer", "split", "regularization_log", ] self.additional_log_keys = 
@torch.no_grad()
def log_images(
    self, inputs: torch.Tensor, reconstructions: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Visualise the discriminator logits for inputs and reconstructions.

    The logits are colour-mapped with the diverging "PiYG" colormap, upsampled
    to the image resolution, arranged in grids (inputs on top, reconstructions
    below) and annotated with a colorbar.

    Args:
        inputs: batch of real images.
        reconstructions: batch of reconstructed images.

    Returns:
        An empty dict when the discriminator output has fewer than four
        dimensions (non patch-discriminator). Otherwise a dict with
        ``"vis_logits"`` (colour-mapped logits) and ``"vis_logits_blended"``
        (logits blended over the images), each with a leading batch axis of
        size 1 and rescaled via ``2 * x - 1``.
    """
    # calc logits of real/fake
    logits_real = self.discriminator(inputs.contiguous().detach())
    if len(logits_real.shape) < 4:
        # Non patch-discriminator
        return dict()
    logits_fake = self.discriminator(reconstructions.contiguous().detach())
    # -> (b, 1, h, w)

    # parameters for colormapping
    high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
    if high == 0:
        # All logits are exactly zero: any positive scale avoids the 0 / 0
        # (NaN) that the normalisations below would otherwise produce.
        high = 1.0
    cmap = colormaps["PiYG"]  # diverging colormap

    def to_colormap(logits: torch.Tensor) -> torch.Tensor:
        """(b, 1, ...) -> (b, 3, ...)"""
        logits = (logits + high) / (2 * high)
        logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel
        # -> (b, 1, ..., 3)
        logits = torch.from_numpy(logits_np).to(logits.device)
        return rearrange(logits, "b 1 ... c -> b c ...")

    # resize logits to the image resolution
    logits_real = torch.nn.functional.interpolate(
        logits_real,
        size=inputs.shape[-2:],
        mode="nearest",
        antialias=False,
    )
    logits_fake = torch.nn.functional.interpolate(
        logits_fake,
        size=reconstructions.shape[-2:],
        mode="nearest",
        antialias=False,
    )

    # alpha value of logits for overlay
    alpha_real = torch.abs(logits_real) / high
    alpha_fake = torch.abs(logits_fake) / high
    # -> (b, 1, h, w) in range [0, 1], since `high` is the largest |logit|
    # alpha value of lines don't really matter, since the values are the same
    # for both images and logits anyway
    grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
    grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
    grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)

    # prepare logits for plotting
    logits_real = to_colormap(logits_real)
    logits_fake = to_colormap(logits_fake)
    # -> (b, 3, h, w)

    # add all logits to one plot
    logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
    logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
    # I just love how torchvision calls the number of columns `nrow`
    grid_logits = torch.cat((logits_real, logits_fake), dim=1)
    # -> (3, h, w)

    grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
    grid_images_fake = torchvision.utils.make_grid(
        0.5 * reconstructions + 0.5, nrow=4
    )
    grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
    # -> (3, h, w) in range [0, 1]

    # blend logits and images together
    grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

    # Create labeled colorbar
    dpi = 100
    height = 128 / dpi
    width = grid_logits.shape[2] / dpi
    fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
    try:
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation="horizontal",
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()
        # Convert the rendered figure to numpy. `buffer_rgba` replaces the
        # deprecated `tostring_rgb`; it yields (h, w, 4) uint8, so the alpha
        # channel is dropped. The copy detaches the data from the canvas
        # buffer before the figure is closed.
        cbar_np = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call leaks one figure.
        plt.close(fig)
    cbar = torch.from_numpy(cbar_np).to(grid_logits.dtype) / 255.0
    cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)

    # Add colorbar to plot
    annotated_grid = torch.cat((grid_logits, cbar), dim=1)
    blended_grid = torch.cat((grid_blend, cbar), dim=1)
    return {
        "vis_logits": 2 * annotated_grid[None, ...] - 1,
        "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
    }
self.training: logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) if self.training: d_weight = self.calculate_adaptive_weight( nll_loss, g_loss, last_layer=last_layer ) else: d_weight = torch.tensor(1.0) else: d_weight = torch.tensor(0.0) g_loss = torch.tensor(0.0, requires_grad=True) loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss log = dict() for k in regularization_log: if k in self.regularization_weights: loss = loss + self.regularization_weights[k] * regularization_log[k] if k in self.additional_log_keys: log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() log.update( { f"{split}/loss/total": loss.clone().detach().mean(), f"{split}/loss/nll": nll_loss.detach().mean(), f"{split}/loss/rec": rec_loss.detach().mean(), f"{split}/loss/g": g_loss.detach().mean(), f"{split}/scalars/logvar": self.logvar.detach(), f"{split}/scalars/d_weight": d_weight.detach(), } ) return loss, log elif optimizer_idx == 1: # second pass for discriminator update logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator(reconstructions.contiguous().detach()) if global_step >= self.discriminator_iter_start or not self.training: d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) else: d_loss = torch.tensor(0.0, requires_grad=True) log = { f"{split}/loss/disc": d_loss.clone().detach().mean(), f"{split}/logits/real": logits_real.detach().mean(), f"{split}/logits/fake": logits_fake.detach().mean(), } return d_loss, log else: raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") def get_nll_loss( self, rec_loss: torch.Tensor, weights: Optional[Union[float, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: weighted_nll_loss = weights * nll_loss weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 
class LatentLPIPS(nn.Module):
    """Latent-space L2 loss combined with LPIPS on decoded images.

    Latents are decoded with ``self.decoder`` and compared with the LPIPS
    metric; optionally the decoded predictions are also compared against the
    given input images.
    """

    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ):
        """Store the weights and build the decoder and the LPIPS metric.

        Args:
            decoder_config: config handed to ``instantiate_from_config`` to
                build the decoder.
            perceptual_weight: weight of LPIPS between decoded targets and
                decoded predictions; the term is skipped unless ``> 0``.
            latent_weight: weight of the mean squared latent error.
            scale_input_to_tgt_size: resize ``image_inputs`` to the size of
                the decoded predictions before the input-image LPIPS term.
            scale_tgt_to_input_size: otherwise, resize the decoded
                predictions to the size of ``image_inputs``.
            perceptual_weight_on_inputs: weight of LPIPS between
                ``image_inputs`` and decoded predictions; skipped unless
                ``> 0``.
        """
        super().__init__()
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        self.init_decoder(decoder_config)
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    def init_decoder(self, config):
        """Instantiate the decoder and drop its ``encoder`` attribute if present."""
        self.decoder = instantiate_from_config(config)
        if hasattr(self.decoder, "encoder"):
            del self.decoder.encoder

    def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
        """Compute the combined loss and a dict of detached log values.

        Args:
            latent_inputs: target latents.
            latent_predictions: predicted latents.
            image_inputs: reference images; only read when
                ``perceptual_weight_on_inputs > 0``.
            split: prefix used for the keys of the log dict.

        Returns:
            ``(loss, log)`` where ``loss`` is
            ``latent_weight * mean((latent_inputs - latent_predictions) ** 2)``
            plus the enabled perceptual terms.
        """
        log = dict()
        latent_l2 = (latent_inputs - latent_predictions) ** 2
        log[f"{split}/latent_l2_loss"] = latent_l2.mean().detach()

        # Always reduce and weight the latent term. Previously this only
        # happened inside the `perceptual_weight > 0` branch, so with the
        # perceptual term disabled the loss stayed an unreduced tensor and
        # `latent_weight` was ignored.
        loss = self.latent_weight * latent_l2.mean()

        image_reconstructions = None
        if self.perceptual_weight > 0.0:
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual_loss = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            )
            loss = loss + self.perceptual_weight * perceptual_loss.mean()
            log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()

        if self.perceptual_weight_on_inputs > 0.0:
            # Decode only when no reconstruction exists yet. The previous
            # `default(x, self.decoder.decode(...))` evaluated the decode
            # eagerly and discarded the result whenever `x` was already set.
            if image_reconstructions is None:
                image_reconstructions = self.decoder.decode(latent_predictions)
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
                    image_reconstructions.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )
            elif self.scale_tgt_to_input_size:
                image_reconstructions = torch.nn.functional.interpolate(
                    image_reconstructions,
                    image_inputs.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )

            perceptual_loss2 = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            )
            loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
            log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
        return loss, log
class LPIPS(nn.Module):
    """Learned perceptual metric (LPIPS) built on VGG16 feature maps."""

    # Directory handed to `get_ckpt_path` as the checkpoint root.
    _CKPT_ROOT = "sgm/modules/autoencoding/lpips/loss"

    def __init__(self, use_dropout=True):
        """Build the feature network and linear heads, then load and freeze weights.

        Args:
            use_dropout: forwarded to every ``NetLinLayer``.
        """
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        # One 1x1-conv head per feature level; the attribute names are part
        # of the checkpoint's state-dict keys, so they must not change.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load the checkpoint ``name`` into this module (non-strict)."""
        ckpt = get_ckpt_path(name, self._CKPT_ROOT)
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Alternate constructor returning a model with weights ``name`` loaded.

        Raises:
            NotImplementedError: for any name other than ``"vgg_lpips"``.
        """
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # `get_ckpt_path` requires a root directory; it was previously called
        # with the name only, which raised a TypeError on every call.
        ckpt = get_ckpt_path(name, cls._CKPT_ROOT)
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        """Return the perceptual distance between ``input`` and ``target``.

        Both images go through the scaling layer and the feature network.
        Per level, the features are normalised with ``normalize_tensor``,
        their squared difference is passed through that level's linear head
        and averaged spatially (``keepdim=True``); the levels are summed.
        """
        outs0 = self.net(self.scaling_layer(input))
        outs1 = self.net(self.scaling_layer(target))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        res = []
        for kk in range(len(self.chns)):
            diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2
            res.append(spatial_average(lins[kk].model(diff), keepdim=True))

        val = res[0]
        for level_score in res[1:]:
            val = val + level_score
        return val
def normalize_tensor(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along dim 1.

    ``eps`` is added to the norm before dividing, so all-zero vectors map to
    zeros instead of NaN.
    """
    squared_sum = torch.sum(x**2, dim=1, keepdim=True)
    denominator = torch.sqrt(squared_sum) + eps
    return x / denominator
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --------------------------- LICENSE FOR pix2pix -------------------------------- BSD License For pix2pix software Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. ----------------------------- LICENSE FOR DCGAN -------------------------------- BSD License For dcgan.torch software Copyright (c) 2015, Facebook, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
def weights_init(m):
    """Re-initialise ``m`` in place according to its class name.

    * class name contains "Conv": weight ~ N(0.0, 0.02)
    * otherwise, class name contains "BatchNorm": weight ~ N(1.0, 0.02), bias = 0
    * anything else is left untouched
    """
    layer_name = m.__class__.__name__
    if "Conv" in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    The payload is first written to ``local_path + ".part"`` and only moved
    to ``local_path`` once the transfer has completed, so an interrupted or
    failed download never leaves a truncated file at the final location.

    Args:
        url: address to fetch with ``requests.get(..., stream=True)``.
        local_path: destination file; its parent directory is created if
            needed.
        chunk_size: number of bytes requested per iteration.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    local_path = os.fspath(local_path)
    target_dir = os.path.dirname(local_path)
    if target_dir:
        # `os.makedirs("")` raises, so skip it for bare file names
        os.makedirs(target_dir, exist_ok=True)

    partial_path = local_path + ".part"
    try:
        with requests.get(url, stream=True) as r:
            # do not store an HTML error page as if it were the payload
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
                with open(partial_path, "wb") as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # advance by what was actually written; the
                            # final chunk is usually shorter than chunk_size
                            pbar.update(len(data))
        os.replace(partial_path, local_path)
    finally:
        # Only true when the transfer failed before the rename above.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name`` under ``root``.

    The file is downloaded when it does not exist yet, or when ``check`` is
    true and the MD5 of the existing file differs from the expected one.
    After a download the MD5 is verified with an assertion.
    """
    assert name in URL_MAP
    ckpt_file = os.path.join(root, CKPT_MAP[name])

    needs_download = not os.path.exists(ckpt_file)
    if not needs_download and check:
        needs_download = md5_hash(ckpt_file) != MD5_MAP[name]
    if not needs_download:
        return ckpt_file

    source_url = URL_MAP[name]
    print("Downloading {} model from {} to {}".format(name, source_url, ckpt_file))
    download(source_url, ckpt_file)
    md5 = md5_hash(ckpt_file)
    assert md5 == MD5_MAP[name], md5
    return ckpt_file
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for a discriminator.

    Averages ``relu(1 - logits_real)`` and ``relu(1 + logits_fake)`` over all
    elements and returns half of their sum.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def measure_perplexity(
    predicted_indices: torch.Tensor, num_centroids: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate how evenly the codebook entries are used.

    Returns a pair ``(perplexity, cluster_use)``: the exponential of the
    entropy of the average one-hot assignment, and the number of centroids
    that were selected at least once.
    """
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # when perplexity == num_embeddings then all clusters are used exactly equally
    one_hot = F.one_hot(predicted_indices, num_centroids).float()
    one_hot = one_hot.reshape(-1, num_centroids)
    usage = one_hot.mean(0)
    entropy = -(usage * torch.log(usage + 1e-10)).sum()
    used_clusters = torch.sum(usage > 0)
    return entropy.exp(), used_clusters
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
    """Translate indices into ``self.used`` back to full-codebook indices.

    Args:
        inds: Integer tensor with at least two dimensions whose entries index
            into ``self.used``. The tensor is not modified.

    Returns:
        A tensor of the same shape as ``inds`` holding ``self.used[inds]``.
        When an extra "unknown" slot exists (``re_embed`` larger than the
        number of used entries), indices pointing at it are mapped to entry 0.
    """
    assert self.used is not None, "You need to define used indices for remap"
    ishape = inds.shape
    assert len(ishape) > 1
    inds = inds.reshape(ishape[0], -1)
    used = self.used.to(inds)
    n_used = used.shape[0]
    if self.re_embed > n_used:  # extra token
        # reshape() may return a view of the caller's tensor, so replace the
        # extra-token indices out of place instead of assigning into it.
        inds = torch.where(inds >= n_used, torch.zeros_like(inds), inds)
    # expand() broadcasts the lookup table over the batch without copying it.
    back = torch.gather(used.unsqueeze(0).expand(inds.shape[0], -1), 1, inds)
    return back.reshape(ishape)
def __init__(
    self,
    num_hiddens: int,
    embedding_dim: int,
    n_embed: int,
    straight_through: bool = True,
    kl_weight: float = 5e-4,
    temp_init: float = 1.0,
    remap: Optional[str] = None,
    unknown_index: Union[str, int] = "random",
    loss_key: str = "loss/vq",
) -> None:
    """Set up the logit projection, the codebook and optional index remapping.

    Args:
        num_hiddens: Input channels of the 1x1 convolution producing logits.
        embedding_dim: Dimension of each codebook vector.
        n_embed: Number of codebook entries (and of logit channels).
        straight_through: Stored as ``self.straight_through``.
        kl_weight: Stored as ``self.kl_weight``.
        temp_init: Stored as ``self.temperature``.
        remap: Optional path passed to ``np.load``; the loaded array is
            registered as the ``used`` buffer.
        unknown_index: ``"random"``, ``"extra"`` or an integer. ``"extra"``
            reserves one additional index after the used ones.
        loss_key: Stored as ``self.loss_key``.
    """
    super().__init__()

    self.loss_key = loss_key
    self.embedding_dim = embedding_dim
    self.n_embed = n_embed

    self.straight_through = straight_through
    self.temperature = temp_init
    self.kl_weight = kl_weight

    self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
    self.embed = nn.Embedding(n_embed, embedding_dim)

    self.remap = remap
    if self.remap is not None:
        # `used` is a buffer, so it is saved with and moved alongside the module
        self.register_buffer("used", torch.tensor(np.load(self.remap)))
        self.re_embed = self.used.shape[0]
    else:
        self.used = None
        self.re_embed = n_embed
    if unknown_index == "extra":
        # the extra slot takes the first index after the used ones
        self.unknown_index = self.re_embed
        self.re_embed = self.re_embed + 1
    else:
        assert unknown_index == "random" or isinstance(
            unknown_index, int
        ), "unknown index needs to be 'random', 'extra' or any integer"
        self.unknown_index = unknown_index  # "random" or "extra" or integer
    if self.remap is not None:
        logpy.info(
            f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
            f"Using {self.unknown_index} for unknown indices."
        )
def get_codebook_entry(self, indices, shape):
    """Look up codebook vectors for flat ``indices``.

    ``shape`` must unpack into four values ``(b, h, w, c)`` with
    ``b * h * w == indices.shape[0]``. The result of the final einsum has
    layout ``b d h w``.
    """
    # TODO: shape not yet optional
    b, h, w, _ = shape
    assert b * h * w == indices.shape[0]
    grid = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
    if self.remap is not None:
        grid = self.unmap_to_all(grid)
    one_hot = F.one_hot(grid, num_classes=self.n_embed)
    one_hot = one_hot.permute(0, 3, 1, 2).float()
    return einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
Inputs: - n_e : number of embeddings - e_dim : dimension of embedding - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 _____________________________________________ """ def __init__( self, n_e: int, e_dim: int, beta: float = 0.25, remap: Optional[str] = None, unknown_index: str = "random", sane_index_shape: bool = False, log_perplexity: bool = False, embedding_weight_norm: bool = False, loss_key: str = "loss/vq", ): super().__init__() self.n_e = n_e self.e_dim = e_dim self.beta = beta self.loss_key = loss_key if not embedding_weight_norm: self.embedding = nn.Embedding(self.n_e, self.e_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) else: self.embedding = torch.nn.utils.weight_norm( nn.Embedding(self.n_e, self.e_dim), dim=1 ) self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] else: self.used = None self.re_embed = n_e if unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 else: assert unknown_index == "random" or isinstance( unknown_index, int ), "unknown index needs to be 'random', 'extra' or any integer" self.unknown_index = unknown_index # "random" or "extra" or integer if self.remap is not None: logpy.info( f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices." 
def forward(
    self,
    z: torch.Tensor,
) -> Tuple[torch.Tensor, Dict]:
    """Quantize ``z`` to its nearest codebook entries.

    Args:
        z: Either a 4-D tensor laid out ``b c h w`` or a tensor with fewer
            than four dimensions whose last dimension is ``e_dim``.

    Returns:
        ``(z_q, loss_dict)``. ``z_q`` has the layout of the input and passes
        gradients straight through to ``z``. ``loss_dict`` holds the
        quantization loss under ``self.loss_key``, the selected indices under
        ``"min_encoding_indices"`` and, if ``log_perplexity`` is set,
        ``"perplexity"`` and ``"cluster_usage"``.
    """
    do_reshape = z.ndim == 4
    if do_reshape:
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, "b c h w -> b h w c").contiguous()
    else:
        assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
        z = z.contiguous()

    z_flattened = z.view(-1, self.e_dim)
    # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
    d = (
        torch.sum(z_flattened**2, dim=1, keepdim=True)
        + torch.sum(self.embedding.weight**2, dim=1)
        - 2
        * torch.einsum(
            "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
        )
    )

    min_encoding_indices = torch.argmin(d, dim=1)
    z_q = self.embedding(min_encoding_indices).view(z.shape)
    loss_dict = {}
    if self.log_perplexity:
        perplexity, cluster_usage = measure_perplexity(
            min_encoding_indices.detach(), self.n_e
        )
        loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})

    # compute loss for embedding
    loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
        (z_q - z.detach()) ** 2
    )
    loss_dict[self.loss_key] = loss

    # preserve gradients
    z_q = z + (z_q - z).detach()

    # reshape back to match original input shape
    if do_reshape:
        z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()

    if self.remap is not None:
        min_encoding_indices = min_encoding_indices.reshape(
            z.shape[0], -1
        )  # add batch axis
        min_encoding_indices = self.remap_to_used(min_encoding_indices)
        min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

    if self.sane_index_shape:
        if do_reshape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3]
            )
        else:
            # The indices are 1-D without remap and (N, 1) with it; a plain
            # reshape to (batch, -1) handles both, whereas the former
            # "(b s) 1 -> b s" pattern rejected the 1-D case.
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], -1)

    loss_dict["min_encoding_indices"] = min_encoding_indices

    return z_q, loss_dict
def __init__(
    self,
    n_embed: int,
    embedding_dim: int,
    beta: float,
    decay: float = 0.99,
    eps: float = 1e-5,
    remap: Optional[str] = None,
    unknown_index: Union[str, int] = "random",
    loss_key: str = "loss/vq",
):
    """Create the EMA codebook and optional index remapping.

    Args:
        n_embed: Number of codebook entries (stored as ``self.num_tokens``).
        embedding_dim: Dimension of each entry (stored as ``self.codebook_dim``).
        beta: Stored as ``self.beta``.
        decay: Forwarded to ``EmbeddingEMA``.
        eps: Forwarded to ``EmbeddingEMA``.
        remap: Optional path passed to ``np.load``; the loaded array is
            registered as the ``used`` buffer.
        unknown_index: ``"random"``, ``"extra"`` or an integer. ``"extra"``
            reserves one additional index after the used ones.
        loss_key: Stored as ``self.loss_key``.
    """
    super().__init__()
    self.codebook_dim = embedding_dim
    self.num_tokens = n_embed
    self.beta = beta
    self.loss_key = loss_key

    self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

    self.remap = remap
    if self.remap is not None:
        self.register_buffer("used", torch.tensor(np.load(self.remap)))
        self.re_embed = self.used.shape[0]
    else:
        self.used = None
        self.re_embed = n_embed
    if unknown_index == "extra":
        self.unknown_index = self.re_embed
        self.re_embed = self.re_embed + 1
    else:
        assert unknown_index == "random" or isinstance(
            unknown_index, int
        ), "unknown index needs to be 'random', 'extra' or any integer"
        self.unknown_index = unknown_index  # "random" or "extra" or integer
    if self.remap is not None:
        # This class stores the codebook size as `num_tokens`; `n_embed` is
        # never assigned, so reading it here raised AttributeError.
        logpy.info(
            f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
            f"Using {self.unknown_index} for unknown indices."
        )
class VectorQuantizerWithInputProjection(VectorQuantizer):
    """``VectorQuantizer`` wrapped in a linear input and output projection.

    The input is projected from ``input_dim`` to ``codebook_dim`` before
    quantization. The quantized result is projected to ``output_dim`` when one
    is given and passed through unchanged otherwise.
    """

    def __init__(
        self,
        input_dim: int,
        n_codes: int,
        codebook_dim: int,
        beta: float = 1.0,
        output_dim: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(n_codes, codebook_dim, beta, **kwargs)
        self.proj_in = nn.Linear(input_dim, codebook_dim)
        self.output_dim = output_dim
        self.proj_out = (
            nn.Identity()
            if output_dim is None
            else nn.Linear(codebook_dim, output_dim)
        )

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Project, quantize and project back.

        Inputs with more than three dimensions are flattened to
        ``b (...) c`` first. The channel-first layout is restored afterwards
        only when ``output_dim`` is set.
        """
        in_shape = z.shape
        n_dims = len(in_shape)
        restore_layout = False
        if z.ndim > 3:
            restore_layout = self.output_dim is not None
            z = rearrange(z, "b c ... -> b (...) c")

        z_q, loss_dict = super().forward(self.proj_in(z))
        z_q = self.proj_out(z_q)

        if not restore_layout:
            return z_q, loss_dict

        if n_dims == 4:
            z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
        elif n_dims == 5:
            z_q = rearrange(
                z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
            )
        else:
            raise NotImplementedError(
                f"rearranging not available for {len(in_shape)}-dimensional input."
            )
        return z_q, loss_dict
class AE3DConv(torch.nn.Conv2d):
    """2-D convolution followed by a 3-D convolution over stacked frames.

    The 3-D convolution maps ``out_channels`` to ``out_channels`` and is padded
    with half of its kernel size (per axis when the kernel size is iterable).
    """

    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        temporal_padding = (
            [int(k // 2) for k in video_kernel_size]
            if isinstance(video_kernel_size, Iterable)
            else int(video_kernel_size // 2)
        )
        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=temporal_padding,
        )

    def forward(self, input, timesteps, skip_video=False):
        """Apply the 2-D conv and, unless ``skip_video``, the 3-D conv.

        For the 3-D conv the leading axis is split as ``(b t)`` with
        ``t=timesteps`` and merged back afterwards.
        """
        spatial = super().forward(input)
        if not skip_video:
            stacked = rearrange(spatial, "(b t) c h w -> b c t h w", t=timesteps)
            mixed = self.time_mix_conv(stacked)
            return rearrange(mixed, "b c t h w -> (b t) c h w")
        return spatial
def get_alpha(
    self,
):
    """Return the mixing factor for the configured merge strategy.

    ``"fixed"`` returns ``self.mix_factor`` as is, ``"learned"`` returns its
    sigmoid; any other strategy raises ``NotImplementedError``.
    """
    strategy = self.merge_strategy
    if strategy == "learned":
        return torch.sigmoid(self.mix_factor)
    if strategy == "fixed":
        return self.mix_factor
    raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    """Build a partially-applied spatio-temporal attention block class.

    Args:
        in_channels: Channel count bound into the returned block.
        attn_type: ``"vanilla"`` or ``"vanilla-xformers"``. The latter falls
            back to ``"vanilla"`` when ``XFORMERS_IS_AVAILABLE`` is false.
        attn_kwargs: Must be ``None`` for ``"vanilla"`` attention.
        alpha: Bound into the returned block.
        merge_strategy: Bound into the returned block.

    Returns:
        ``partialclass`` of ``VideoBlock`` or ``MemoryEfficientVideoBlock``.

    Raises:
        NotImplementedError: If ``attn_type`` reaches the dispatch with an
            unsupported value (only possible when assertions are disabled).
    """
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
    ], f"attn_type {attn_type} not supported for spatio-temporal attention"
    print(
        f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
    )
    if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
        print(
            f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
            f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
        )
        attn_type = "vanilla"

    if attn_type == "vanilla":
        assert attn_kwargs is None
        return partialclass(
            VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
        )
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return partialclass(
            MemoryEfficientVideoBlock,
            in_channels,
            alpha=alpha,
            merge_strategy=merge_strategy,
        )
    else:
        # Raise rather than return the exception instance, which callers would
        # otherwise mistake for a block factory.
        raise NotImplementedError(
            f"attn_type {attn_type} not supported for spatio-temporal attention"
        )
class Denoiser(nn.Module):
    """Wrap a network with sigma-dependent input/output scaling.

    The scaling object is built from ``scaling_config`` and must return
    ``(c_skip, c_out, c_in, c_noise)`` when called with sigma.
    """

    def __init__(self, scaling_config: Dict):
        super().__init__()
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Hook for subclasses; the base class returns ``sigma`` unchanged."""
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        """Hook for subclasses; the base class returns ``c_noise`` unchanged."""
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """Return ``network(input * c_in, c_noise, cond) * c_out + input * c_skip``.

        The scaling factors are computed from sigma expanded to ``input.ndim``
        dimensions; ``c_noise`` is reshaped back to sigma's original shape
        before it is handed to the network.
        """
        quantized = self.possibly_quantize_sigma(sigma)
        flat_shape = quantized.shape
        c_skip, c_out, c_in, c_noise = self.scaling(
            append_dims(quantized, input.ndim)
        )
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(flat_shape))
        prediction = network(input * c_in, c_noise, cond, **additional_model_inputs)
        return prediction * c_out + input * c_skip
class EDMScaling:
    """Scaling factors parameterised by ``sigma_data``.

    Calling the object with ``sigma`` returns ``(c_skip, c_out, c_in, c_noise)``
    computed from ``sigma`` and ``sigma_data``; ``c_noise`` is
    ``0.25 * log(sigma)``.
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        total_var = sigma**2 + self.sigma_data**2
        total_std = total_var**0.5
        c_skip = self.sigma_data**2 / total_var
        c_out = sigma * self.sigma_data / total_std
        c_in = 1 / total_std
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    """Pick ``num_substeps`` integer steps ending at ``max_step - 1``.

    The steps are evenly spaced between ``max_step - 1`` and 0 (0 excluded),
    truncated to integers and returned in ascending order.
    """
    descending = np.linspace(max_step - 1, 0, num_substeps, endpoint=False)
    as_int = descending.astype(int)
    return as_int[::-1]
class IdentityGuider(Guider):
    """Guider that applies no guidance: predictions and inputs pass through."""

    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        # The prediction is returned untouched; `sigma` is not used.
        return x

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        # Only the conditional dict is forwarded (as a shallow copy);
        # `uc` is ignored and the batch is not duplicated.
        forwarded = {key: c[key] for key in c}
        return x, s, forwarded
class TrianglePredictionGuider(LinearPredictionGuider):
    """
    Guider whose per-frame scale follows one or more triangle waves.

    The waves are evaluated on ``torch.linspace(0, 1, num_frames)``, fused
    into a single profile, and mapped affinely onto
    ``[min_scale, max_scale]``. The result replaces the linear ``self.scale``
    set by the parent constructor.

    :param max_scale: scale assigned where the fused wave equals 1.
    :param num_frames: number of frames (length of the scale profile).
    :param min_scale: scale assigned where the fused wave equals 0.
    :param period: a single period or a list of periods, one wave each.
    :param period_fusing: how several waves are combined: "mean",
        "multiply" or "max".
    :param additional_cond_keys: forwarded to ``LinearPredictionGuider``.
    :raises ValueError: if ``period`` is empty or ``period_fusing`` is not
        one of the supported modes.
    """

    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        period: Union[float, List[float]] = 1.0,
        period_fusing: Literal["mean", "multiply", "max"] = "max",
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
        values = torch.linspace(0, 1, num_frames)

        # Accept any real scalar: an integer period (e.g. `period=1`) used to
        # skip this wrapping and then fail when iterated below.
        if isinstance(period, (int, float)):
            period = [period]
        if len(period) == 0:
            raise ValueError("period must contain at least one value")

        # One triangle wave per period.
        scales = [self.triangle_wave(values, p) for p in period]

        if period_fusing == "mean":
            scale = sum(scales) / len(scales)
        elif period_fusing == "multiply":
            scale = torch.prod(torch.stack(scales), dim=0)
        elif period_fusing == "max":
            scale = torch.max(torch.stack(scales), dim=0).values
        else:
            # Previously fell through and hit an unbound `scale` below.
            raise ValueError(
                f"Unknown period_fusing {period_fusing!r}; "
                "expected 'mean', 'multiply' or 'max'"
            )
        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)

    def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
        """Triangle wave in [0, 1] with the given period; 0 at multiples of it."""
        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
class StandardDiffusionLoss(nn.Module):
    """
    Denoising loss: noise the input at sampled sigmas, run the denoiser and
    compare its output with the clean input.

    :param sigma_sampler_config: config instantiated into the sigma sampler;
        it is called with the batch size.
    :param loss_weighting_config: config instantiated into the per-sigma
        loss weighting callable.
    :param loss_type: one of "l2", "l1", "lpips".
    :param offset_noise_level: if > 0, extra low-dimensional noise scaled by
        this factor is added to the sampled noise.
    :param batch2model_keys: batch key(s) forwarded to the denoiser as extra
        keyword inputs when present in the batch.
    :param n_frames: only checked for ``None``. When set, offset noise has
        shape ``(B, 1, input.shape[2])`` instead of ``(B, input.shape[1])``.
        NOTE(review): presumably meant for video-shaped inputs — confirm.
    """

    def __init__(
        self,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
        n_frames: Optional[int] = None,
    ):
        super().__init__()

        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)

        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level
        # `_forward` reads this attribute; it was never assigned before, so
        # enabling offset noise raised AttributeError.
        self.n_frames = n_frames

        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        """Return ``input + noise * sigmas_bc``."""
        noised_input = input + noise * sigmas_bc
        return noised_input

    def forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        conditioner: GeneralConditioner,
        input: torch.Tensor,
        batch: Dict,
    ) -> torch.Tensor:
        """Compute the conditioning from ``batch`` and return the loss."""
        cond = conditioner(batch)
        return self._forward(network, denoiser, cond, input, batch)

    def _forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        cond: Dict,
        input: torch.Tensor,
        batch: Dict,
    ) -> torch.Tensor:
        """Sample sigmas and noise, denoise, and return the weighted loss."""
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            offset_shape = (
                (input.shape[0], 1, input.shape[2])
                if self.n_frames is not None
                else (input.shape[0], input.shape[1])
            )
            # The offset is expanded with trailing singleton dims so it
            # broadcasts over the remaining axes of `input`.
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(offset_shape, device=input.device),
                input.ndim,
            )
        sigmas_bc = append_dims(sigmas, input.ndim)
        noised_input = self.get_noised_input(sigmas_bc, noise, input)

        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        w = append_dims(self.loss_weighting(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        """
        Return one loss value per batch element.

        "l2"/"l1" average the weighted error over all non-batch dimensions;
        "lpips" returns the flattened LPIPS output and does not use ``w``.
        """
        if self.loss_type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
        else:
            raise NotImplementedError(f"Unknown loss type {self.loss_type}")
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    The first ``embedding_dim // 2`` columns are sines and the next
    ``embedding_dim // 2`` are cosines of ``timesteps * freq``, with
    frequencies spaced geometrically from 1 down to 1/10000. An odd
    ``embedding_dim`` gets one trailing zero column.

    This matches the implementation in Denoising Diffusion Probabilistic
    Models (from Fairseq / tensor2tensor), but differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".

    :param timesteps: 1-D tensor of timesteps.
    :param embedding_dim: width of the returned embedding.
    :return: float tensor of shape ``(len(timesteps), embedding_dim)``.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim 2 or 3) `half_dim - 1` is 0 and
    # the division used to raise ZeroDivisionError. The only exponent is then
    # 0, so any non-zero denominator gives the same frequency of 1.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class LinAttnBlock(LinearAttention):
    """
    Single-head ``LinearAttention`` whose head width equals ``in_channels``,
    constructible with the same one-argument signature as ``AttnBlock``.
    """

    def __init__(self, in_channels):
        attention_config = dict(dim=in_channels, heads=1, dim_head=in_channels)
        super().__init__(**attention_config)
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    """
    Adapter that lets ``MemoryEfficientCrossAttention`` run on
    ``(b, c, h, w)`` feature maps and adds a residual connection.
    """

    def forward(self, x, context=None, mask=None, **unused_kwargs):
        b, c, h, w = x.shape
        # Keep `x` in its original layout for the residual. It used to be
        # overwritten by the flattened tokens, so `x + out` mixed a
        # (b, h*w, c) tensor with a (b, c, h, w) tensor.
        tokens = rearrange(x, "b c h w -> b (h w) c")
        out = super().forward(tokens, context=context, mask=mask)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        return x + out


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """
    Build the attention module selected by ``attn_type``.

    :param in_channels: channel count of the feature map the module sees.
    :param attn_type: "vanilla", "vanilla-xformers",
        "memory-efficient-cross-attn", "linear" or "none".
    :param attn_kwargs: extra constructor arguments, used only for
        "memory-efficient-cross-attn"; must be None for "vanilla". The dict
        is not modified.
    :return: the constructed ``nn.Module``.

    On torch < 2.0 every type except "none" is replaced by
    "vanilla-xformers", which requires xformers to be importable.
    """
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    if (
        version.parse(torch.__version__) < version.parse("2.0.0")
        and attn_type != "none"
    ):
        assert XFORMERS_IS_AVAILABLE, (
            f"We do not support vanilla attention in {torch.__version__} anymore, "
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        attn_type = "vanilla-xformers"
    logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        logpy.info(
            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
        )
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        # This branch used to compare the builtin `type`, so it never matched
        # and the request silently fell through to LinAttnBlock.
        # Copy the kwargs so the caller's dict is untouched and None works.
        wrapper_kwargs = dict(attn_kwargs or {})
        wrapper_kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**wrapper_kwargs)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)
def forward(self, x, t=None, context=None):
    """
    Run the network: down path, middle blocks, up path with skip
    connections, then the output head.

    :param x: input feature map.
    :param t: timesteps; required when ``self.use_timestep`` is set.
    :param context: optional tensor concatenated to ``x`` on dim 1.
    :return: output of ``self.conv_out``.
    """
    # assert x.shape[2] == x.shape[3] == self.resolution
    if context is not None:
        # assume aligned context, cat along channel axis
        x = torch.cat((x, context), dim=1)
    if self.use_timestep:
        # timestep embedding: sinusoidal features, then two dense layers
        # with the nonlinearity in between
        assert t is not None
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)
    else:
        temb = None

    # downsampling
    # `hs` is the skip stack: it receives the conv_in output, every block
    # output and every downsample output, and is popped on the way up.
    hs = [self.conv_in(x)]
    for i_level in range(self.num_resolutions):
        for i_block in range(self.num_res_blocks):
            h = self.down[i_level].block[i_block](hs[-1], temb)
            if len(self.down[i_level].attn) > 0:
                h = self.down[i_level].attn[i_block](h)
            hs.append(h)
        if i_level != self.num_resolutions - 1:
            hs.append(self.down[i_level].downsample(hs[-1]))

    # middle
    h = hs[-1]
    h = self.mid.block_1(h, temb)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h, temb)

    # upsampling
    for i_level in reversed(range(self.num_resolutions)):
        # one more block per level than on the way down, so that every
        # entry pushed onto `hs` is consumed
        for i_block in range(self.num_res_blocks + 1):
            h = self.up[i_level].block[i_block](
                torch.cat([h, hs.pop()], dim=1), temb
            )
            if len(self.up[i_level].attn) > 0:
                h = self.up[i_level].attn[i_block](h)
        if i_level != 0:
            h = self.up[i_level].upsample(h)

    # end
    h = self.norm_out(h)
    h = nonlinearity(h)
    h = self.conv_out(h)
    return h
def forward(self, x):
    """
    Encode ``x``: input conv, per-level res/attention blocks with
    downsampling between levels, middle blocks, then the output head.

    :param x: input feature map.
    :return: output of ``self.conv_out``.
    """
    # The encoder has no timestep conditioning; blocks receive temb=None.
    temb = None

    # downsampling
    # Only the most recent activation is ever read here, so a single running
    # tensor replaces the former `hs` list. That list kept every intermediate
    # feature map alive until the method returned.
    h = self.conv_in(x)
    for i_level in range(self.num_resolutions):
        stage = self.down[i_level]
        has_attn = len(stage.attn) > 0
        for i_block in range(self.num_res_blocks):
            h = stage.block[i_block](h, temb)
            if has_attn:
                h = stage.attn[i_block](h)
        if i_level != self.num_resolutions - 1:
            h = stage.downsample(h)

    # middle
    h = self.mid.block_1(h, temb)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h, temb)

    # end
    h = self.norm_out(h)
    h = nonlinearity(h)
    h = self.conv_out(h)
    return h
def _make_attn(self) -> Callable:
    """
    Return the factory that ``__init__`` uses to build attention layers.

    The result is called as ``factory(block_in, attn_type=attn_type)``.
    NOTE(review): presumably a hook so subclasses can substitute a different
    attention factory — confirm against subclasses.
    """
    return make_attn
def forward(self, z, **kwargs):
    """
    Decode ``z``: input conv, middle blocks, per-level res/attention blocks
    with upsampling between levels, then (optionally) the output head.

    Extra keyword arguments are forwarded to the middle blocks, the up-path
    blocks and attention layers, and ``conv_out``. The shape of ``z`` is
    recorded in ``self.last_z_shape``.

    :return: the pre-head features if ``self.give_pre_end`` is set,
        otherwise the ``conv_out`` output (through tanh if
        ``self.tanh_out``).
    """
    # assert z.shape[1:] == self.z_shape[1:]
    self.last_z_shape = z.shape

    # The decoder has no timestep conditioning.
    no_temb = None

    # z to block_in
    feat = self.conv_in(z)

    # middle
    mid = self.mid
    feat = mid.block_1(feat, no_temb, **kwargs)
    feat = mid.attn_1(feat, **kwargs)
    feat = mid.block_2(feat, no_temb, **kwargs)

    # upsampling, from the coarsest level down to level 0
    blocks_per_level = self.num_res_blocks + 1
    for level in range(self.num_resolutions - 1, -1, -1):
        stage = self.up[level]
        use_attn = len(stage.attn) > 0
        for idx in range(blocks_per_level):
            feat = stage.block[idx](feat, no_temb, **kwargs)
            if use_attn:
                feat = stage.attn[idx](feat, **kwargs)
        if level != 0:
            feat = stage.upsample(feat)

    # end
    if not self.give_pre_end:
        feat = self.norm_out(feat)
        feat = nonlinearity(feat)
        feat = self.conv_out(feat, **kwargs)
        if self.tanh_out:
            feat = torch.tanh(feat)
    return feat
def forward(self, x: th.Tensor) -> th.Tensor:
    """
    Pool a feature map into one vector per batch element.

    All dimensions after the channel axis are flattened into a token axis,
    a mean token is prepended, the positional embedding is added, and the
    attention output at the mean token's position is returned.

    :param x: tensor of shape ``(b, c, *spatial)``.
    :return: the first token of the ``c_proj`` output.
    """
    # Starred unpacking accepts any number of trailing dims. The former
    # `b, c, _ = x.shape` only accepted already-flattened 3-D input and
    # raised ValueError on a (b, c, h, w) feature map.
    b, c, *_spatial = x.shape
    x = x.reshape(b, c, -1)
    x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
    x = x + self.positional_embedding[None, :, :].to(x.dtype)
    x = self.qkv_proj(x)
    x = self.attention(x)
    x = self.c_proj(x)
    return x[:, :, 0]
class Upsample(nn.Module):
    """
    Nearest-neighbour upsampling layer with an optional convolution.

    :param channels: channels in the inputs (checked in ``forward``).
    :param use_conv: a bool determining if a convolution is applied after
                     the interpolation.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, the two
                 innermost dimensions are always upsampled and the third one
                 only when ``third_up`` is set.
    :param out_channels: output channels of the convolution; defaults to
                         ``channels``.
    :param padding: padding of the convolution.
    :param third_up: for 3D signals, also upsample dimension 2.
    :param kernel_size: kernel size of the convolution.
    :param scale_factor: upsampling factor.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool,
        dims: int = 2,
        out_channels: Optional[int] = None,
        padding: int = 1,
        third_up: bool = False,
        kernel_size: int = 3,
        scale_factor: int = 2,
    ):
        super().__init__()
        self.channels, self.use_conv, self.dims = channels, use_conv, dims
        self.out_channels = out_channels or channels
        self.third_up = third_up
        self.scale_factor = scale_factor
        if not use_conv:
            return
        self.conv = conv_nd(
            dims, self.channels, self.out_channels, kernel_size, padding=padding
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        assert x.shape[1] == self.channels

        factor = self.scale_factor
        if self.dims == 3:
            # Explicit target size: dimension 2 is scaled only on request.
            depth_factor = factor if self.third_up else 1
            target_size = (
                depth_factor * x.shape[2],
                x.shape[3] * factor,
                x.shape[4] * factor,
            )
            out = F.interpolate(x, target_size, mode="nearest")
        else:
            out = F.interpolate(x, scale_factor=factor, mode="nearest")

        return self.conv(out) if self.use_conv else out
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__( self, channels: int, use_conv: bool, dims: int = 2, out_channels: Optional[int] = None, padding: int = 1, third_down: bool = False, ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) if use_conv: logpy.info(f"Building a Downsample layer with {dims} dims.") logpy.info( f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " f"kernel-size: 3, stride: {stride}, padding: {padding}" ) if dims == 3: logpy.info(f" --> Downsampling third axis (time): {third_down}") self.op = conv_nd( dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, ) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x: th.Tensor) -> th.Tensor: assert x.shape[1] == self.channels return self.op(x) class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. 
""" def __init__( self, channels: int, emb_channels: int, dropout: float, out_channels: Optional[int] = None, use_conv: bool = False, use_scale_shift_norm: bool = False, dims: int = 2, use_checkpoint: bool = False, up: bool = False, down: bool = False, kernel_size: int = 3, exchange_temb_dims: bool = False, skip_t_emb: bool = False, ): super().__init__() self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm self.exchange_temb_dims = exchange_temb_dims if isinstance(kernel_size, Iterable): padding = [k // 2 for k in kernel_size] else: padding = kernel_size // 2 self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) else: self.h_upd = self.x_upd = nn.Identity() self.skip_t_emb = skip_t_emb self.emb_out_channels = ( 2 * self.out_channels if use_scale_shift_norm else self.out_channels ) if self.skip_t_emb: logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}") assert not self.use_scale_shift_norm self.emb_layers = None self.exchange_temb_dims = False else: self.emb_layers = nn.Sequential( nn.SiLU(), linear( emb_channels, self.emb_out_channels, ), ) self.out_layers = nn.Sequential( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), zero_module( conv_nd( dims, self.out_channels, self.out_channels, kernel_size, padding=padding, ) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, kernel_size, padding=padding ) else: self.skip_connection = 
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input (and output) channels.
    :param num_heads: number of attention heads; only used when
                      `num_head_channels` is -1.
    :param num_head_channels: if not -1, fixed channel width per head; the
                      head count is then `channels // num_head_channels`.
    :param use_checkpoint: if True, run the block under gradient
                      checkpointing.
    :param use_new_attention_order: if True use `QKVAttention` (split qkv
                      before heads), otherwise `QKVAttentionLegacy`.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 1,
        num_head_channels: int = -1,
        use_checkpoint: bool = False,
        use_new_attention_order: bool = False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:
        """
        Apply self-attention over all non-batch, non-channel positions of `x`.

        :param x: an [N x C x ...] Tensor of features.
        :param kwargs: accepted for call-site compatibility and ignored.
        :return: a Tensor with the same shape as `x`.
        """
        # Honour the constructor flag, mirroring ResBlock.forward; previously
        # the block was checkpointed unconditionally and the flag was unused.
        if self.use_checkpoint:
            return checkpoint(self._forward, x)
        return self._forward(x)

    def _forward(self, x: th.Tensor) -> th.Tensor:
        b, c, *spatial = x.shape
        # Flatten every trailing dimension into one sequence axis.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        # Residual connection, then restore the original trailing shape.
        return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
    """
    Multi-head QKV attention that splits the packed tensor into q, k and v
    first and into heads afterwards.
    """

    def __init__(self, n_heads: int):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv: th.Tensor) -> th.Tensor:
        """
        Run attention on a packed qkv tensor.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, packed_width, seq_len = qkv.shape
        assert packed_width % (3 * self.n_heads) == 0
        head_dim = packed_width // (3 * self.n_heads)
        per_head_shape = (batch * self.n_heads, head_dim, seq_len)

        queries, keys, values = qkv.chunk(3, dim=1)

        # Scale q and k separately before the product instead of dividing the
        # product afterwards: more stable with f16.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        scaled_q = (queries * scale).view(*per_head_shape)
        scaled_k = (keys * scale).view(*per_head_shape)
        logits = th.einsum("bct,bcs->bts", scaled_q, scaled_k)

        # Softmax in float32, then cast back to the working dtype.
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", probs, values.reshape(*per_head_shape))
        return attended.reshape(batch, -1, seq_len)
For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially increased efficiency. 
""" def __init__( self, in_channels: int, model_channels: int, out_channels: int, num_res_blocks: int, attention_resolutions: int, dropout: float = 0.0, channel_mult: Union[List, Tuple] = (1, 2, 4, 8), conv_resample: bool = True, dims: int = 2, num_classes: Optional[Union[int, str]] = None, use_checkpoint: bool = False, num_heads: int = -1, num_head_channels: int = -1, num_heads_upsample: int = -1, use_scale_shift_norm: bool = False, resblock_updown: bool = False, transformer_depth: int = 1, context_dim: Optional[int] = None, disable_self_attentions: Optional[List[bool]] = None, num_attention_blocks: Optional[List[int]] = None, disable_middle_self_attn: bool = False, disable_middle_transformer: bool = False, use_linear_in_transformer: bool = False, spatial_transformer_attn_type: str = "softmax", adm_in_channels: Optional[int] = None, ): super().__init__() if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert ( num_head_channels != -1 ), "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: assert ( num_heads != -1 ), "Either num_heads or num_head_channels has to be set" self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels if isinstance(transformer_depth, int): transformer_depth = len(channel_mult) * [transformer_depth] transformer_depth_middle = transformer_depth[-1] if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): raise ValueError( "provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult" ) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all( map( lambda i: self.num_res_blocks[i] >= 
num_attention_blocks[i], range(len(num_attention_blocks)), ) ) logpy.info( f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " f"This option has LESS priority than attention_resolutions {attention_resolutions}, " f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " f"attention will still not be set." ) self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": logpy.info("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "timestep": self.label_emb = nn.Sequential( Timestep(model_channels), nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ), ) elif self.num_classes == "sequential": assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) ) else: raise ValueError self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, 
use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if context_dim is not None and exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False if ( not exists(num_attention_blocks) or nr < num_attention_blocks[level] ): layers.append( SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, attn_type=spatial_transformer_attn_type, use_checkpoint=use_checkpoint, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, attn_type=spatial_transformer_attn_type, use_checkpoint=use_checkpoint, ) if not disable_middle_transformer else th.nn.Identity(), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, 
use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False if ( not exists(num_attention_blocks) or i < num_attention_blocks[level] ): layers.append( SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, attn_type=spatial_transformer_attn_type, use_checkpoint=use_checkpoint, ) ) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) def forward( self, x: th.Tensor, timesteps: Optional[th.Tensor] = None, context: Optional[th.Tensor] = None, y: Optional[th.Tensor] = None, **kwargs, ) -> th.Tensor: """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. 
class BaseDiffusionSampler:
    """
    Shared plumbing for the diffusion samplers in this module: builds the
    discretization and guider from configs and provides the helpers that the
    concrete samplers' `__call__` loops use.
    """

    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        """
        :param discretization_config: config instantiated into the callable
            that produces the sigma schedule.
        :param num_steps: default number of steps requested from the
            discretization; can be overridden per call.
        :param guider_config: config for the guider; `DEFAULT_GUIDER` is used
            when it is None.
        :param verbose: if True, print the sampling setup and show a progress
            bar.
        :param device: device passed to the discretization.
        """
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            )
        )
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """
        Compute the sigma schedule and scale the initial sample.

        :param x: initial sample; it is not modified, a scaled copy is
            returned.
        :param cond: conditioning.
        :param uc: unconditional conditioning; falls back to `cond`.
        :param num_steps: overrides `self.num_steps` when given.
        :return: `(x, s_in, sigmas, num_sigmas, cond, uc)` where `s_in` is a
            vector of ones with one entry per batch element.
        """
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        uc = default(uc, cond)

        # Out-of-place on purpose: an in-place `x *= ...` would silently
        # rescale the caller's tensor and corrupt it if it is reused.
        x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)

        s_in = x.new_ones([x.shape[0]])

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, uc):
        """
        Run the denoiser on the guider-prepared inputs and let the guider
        combine the result.
        """
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        """
        Return the iterable of step indices `0 .. num_sigmas - 2`.

        Each step consumes a pair of consecutive sigmas, so there is one step
        fewer than there are sigmas. When `self.verbose` is set the setup is
        printed and the iterable is wrapped in a progress bar.
        """
        num_steps = num_sigmas - 1
        sigma_generator = range(num_steps)
        if self.verbose:
            print("#" * 30, " Sampling setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            # `total` must match the length of the wrapped range, otherwise
            # the bar stops one short of completion.
            sigma_generator = tqdm(
                sigma_generator,
                total=num_steps,
                desc=f"Sampling with {self.__class__.__name__} for {num_steps} steps",
            )
        return sigma_generator
class LinearMultistepSampler(BaseDiffusionSampler):
    """
    Linear multistep sampler: each update combines up to `order` of the most
    recent derivative estimates, weighted by `linear_multistep_coeff`.
    """

    def __init__(
        self,
        order=4,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.order = order

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        # The coefficient helper receives the schedule as a NumPy array.
        sigmas_cpu = sigmas.detach().cpu().numpy()
        history = []  # most recent derivative estimates, oldest first
        for step in self.get_sigma_gen(num_sigmas):
            sigma = s_in * sigmas[step]
            model_inputs = self.guider.prepare_inputs(x, sigma, cond, uc)
            denoised = self.guider(denoiser(*model_inputs, **kwargs), sigma)

            history.append(to_d(x, sigma, denoised))
            if len(history) > self.order:
                del history[0]

            # Fewer than `order` estimates exist during the first steps.
            cur_order = min(step + 1, self.order)
            weights = [
                linear_multistep_coeff(cur_order, sigmas_cpu, step, j)
                for j in range(cur_order)
            ]
            # Newest estimate pairs with the first coefficient.
            x = x + sum(w * d for w, d in zip(weights, reversed(history)))

        return x
class DPMPP2SAncestralSampler(AncestralSampler):
    """
    Ancestral sampler that takes a two-stage step (one extra denoiser call at
    the midpoint in negative-log-sigma space) and falls back to the Euler
    result wherever the target noise level is zero.
    """

    def get_variables(self, sigma, sigma_down):
        """Return step size `h`, midpoint `s` and endpoints `t`, `t_next`."""
        t = to_neg_log_sigma(sigma)
        t_next = to_neg_log_sigma(sigma_down)
        h = t_next - t
        s = t + 0.5 * h
        return h, s, t, t_next

    def get_mult(self, h, s, t, t_next):
        """Return the four multipliers used by the two stages of the step."""
        midpoint_ratio = to_sigma(s) / to_sigma(t)
        midpoint_expm1 = (-0.5 * h).expm1()
        full_ratio = to_sigma(t_next) / to_sigma(t)
        full_expm1 = (-h).expm1()
        return midpoint_ratio, midpoint_expm1, full_ratio, full_expm1

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)

        if not torch.sum(sigma_down) < 1e-14:
            h, s, t, t_next = self.get_variables(sigma, sigma_down)
            mid_ratio, mid_expm1, full_ratio, full_expm1 = [
                append_dims(factor, x.ndim) for factor in self.get_mult(h, s, t, t_next)
            ]

            # Stage 1: move to the midpoint and denoise there.
            x_mid = mid_ratio * x - mid_expm1 * denoised
            denoised_mid = self.denoise(x_mid, denoiser, to_sigma(s), cond, uc)
            # Stage 2: full step using the midpoint estimate.
            x_dpmpp2s = full_ratio * x - full_expm1 * denoised_mid

            # apply correction if noise level is not 0
            x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
        else:
            # Save a network evaluation if all noise levels are 0
            x = x_euler

        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
        return x
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """
    Split a step from `sigma_from` to `sigma_to` into a deterministic part
    and a noise part.

    :param sigma_from: current noise level(s).
    :param sigma_to: target noise level(s).
    :param eta: amount of ancestral noise; a falsy value (0 / 0.0 / None)
        disables it.
    :return: `(sigma_down, sigma_up)` with
        `sigma_down**2 + sigma_up**2 == sigma_to**2`. `sigma_down` is the level
        to step to deterministically, `sigma_up` the standard deviation of the
        noise to add afterwards.
    """
    if not eta:
        # No ancestral noise: step straight to sigma_to. For tensor inputs
        # return a zero *tensor* so that callers which hand sigma_up to tensor
        # helpers get the same type as in the eta > 0 branch; plain numbers
        # keep the original float result.
        if torch.is_tensor(sigma_to):
            return sigma_to, torch.zeros_like(sigma_to)
        return sigma_to, 0.0
    sigma_up = torch.minimum(
        sigma_to,
        eta
        * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
    )
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
class ZeroSampler:
    """
    Sigma sampler that always yields the constant 1e-5, shaped like `rand`
    when it is given and like a length-`n_samples` vector otherwise.
    """

    def __call__(
        self, n_samples: int, rand: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # The fallback tensor is only used as a template for zeros_like.
        fallback = torch.randn((n_samples,))
        template = default(rand, fallback)
        return torch.zeros_like(template) + 1.0e-5
def make_beta_schedule(
    schedule,
    n_timestep,
    linear_start=1e-4,
    linear_end=2e-2,
):
    """
    Build a beta schedule as a float64 NumPy array of length `n_timestep`.

    Only the "linear" schedule is implemented: the square roots of the betas
    are spaced linearly between `linear_start**0.5` and `linear_end**0.5`.

    :param schedule: name of the schedule; must be "linear".
    :param n_timestep: number of betas to produce.
    :param linear_start: first beta of the linear schedule.
    :param linear_end: last beta of the linear schedule.
    :return: a NumPy array of betas.
    :raises ValueError: if `schedule` is not a supported name.
    """
    # Fail with a clear message instead of the UnboundLocalError that an
    # unknown schedule name would otherwise trigger at the return statement.
    if schedule != "linear":
        raise ValueError(f"schedule '{schedule}' unknown.")

    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()
class MixedCheckpointFunction(torch.autograd.Function):
    """
    Gradient-checkpointing autograd function for keyword-argument inputs that
    mix tensors and non-tensors.

    ``forward`` runs ``run_function`` under ``torch.no_grad()``; ``backward``
    re-runs it with gradients enabled and differentiates the recomputed output
    with respect to the tensor inputs and the extra parameters.
    """

    @staticmethod
    def forward(
        ctx,
        run_function,
        length_tensors,
        length_non_tensors,
        tensor_keys,
        non_tensor_keys,
        *args,
    ):
        """
        :param run_function: callable invoked as ``run_function(**kwargs)``.
        :param length_tensors: number of leading entries of ``args`` that are
            tensor inputs.
        :param length_non_tensors: number of entries following them that are
            non-tensor inputs.
        :param tensor_keys: keyword names for the tensor inputs, in order.
        :param non_tensor_keys: keyword names for the non-tensor inputs.
        :param args: tensor inputs, then non-tensor inputs, then parameters.
        :return: whatever ``run_function`` returns.
        """
        # Slice boundaries inside `args`: [0, end_tensors) are tensor inputs,
        # [end_tensors, end_non_tensors) are non-tensor inputs, the rest are
        # parameters.
        ctx.end_tensors = length_tensors
        ctx.end_non_tensors = length_tensors + length_non_tensors
        # Snapshot the autocast state so the recomputation in `backward` runs
        # under the same settings as this forward pass.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        assert (
            len(tensor_keys) == length_tensors
            and len(non_tensor_keys) == length_non_tensors
        )

        # Rebuild the keyword dictionaries from the flat positional `args`.
        ctx.input_tensors = {
            key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
        }
        ctx.input_non_tensors = {
            key: val
            for (key, val) in zip(
                non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
            )
        }
        ctx.run_function = run_function
        ctx.input_params = list(args[ctx.end_non_tensors :])

        # No graph is recorded here; activations are recomputed in backward.
        with torch.no_grad():
            output_tensors = ctx.run_function(
                **ctx.input_tensors, **ctx.input_non_tensors
            )
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """
        Recompute the forward pass with gradients enabled and return one
        gradient slot per argument of ``forward`` (``None`` for the
        non-differentiable ones).
        """
        # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
        # Detach the stored tensor inputs so they become fresh leaves that
        # `torch.autograd.grad` can differentiate against.
        ctx.input_tensors = {
            key: ctx.input_tensors[key].detach().requires_grad_(True)
            for key in ctx.input_tensors
        }

        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = {
                key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
                for key in ctx.input_tensors
            }
            # shallow_copies.update(additional_args)
            output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
        # allow_unused=True: inputs/params that do not influence the output get
        # a None gradient instead of raising.
        input_grads = torch.autograd.grad(
            output_tensors,
            list(ctx.input_tensors.values()) + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # Layout mirrors the arguments of `forward`: five Nones for
        # (run_function, length_tensors, length_non_tensors, tensor_keys,
        # non_tensor_keys), then tensor-input grads, one None per non-tensor
        # input, and finally the parameter grads.
        return (
            (None, None, None, None, None)
            + input_grads[: ctx.end_tensors]
            + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
            + input_grads[ctx.end_tensors :]
        )
class CheckpointFunction(torch.autograd.Function):
    """
    Gradient checkpointing for positional inputs: the forward pass stores no
    activations, the backward pass re-runs ``run_function`` to rebuild them.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """
        :param run_function: callable invoked as ``run_function(*inputs)``.
        :param length: how many leading entries of ``args`` are inputs; the
            remaining entries are parameters ``run_function`` depends on.
        :return: the output of ``run_function`` computed without a graph.
        """
        inputs, params = args[:length], args[length:]
        ctx.run_function = run_function
        ctx.input_tensors = list(inputs)
        ctx.input_params = list(params)
        # Remember the autocast configuration for the recomputation.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            result = ctx.run_function(*ctx.input_tensors)
        return result

    @staticmethod
    def backward(ctx, *output_grads):
        """Recompute the forward pass and return grads for inputs + params."""
        leaves = []
        for tensor in ctx.input_tensors:
            leaves.append(tensor.detach().requires_grad_(True))
        ctx.input_tensors = leaves

        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Views shield the detached leaves from a first op that would
            # otherwise modify their storage in place.
            views = [leaf.view_as(leaf) for leaf in ctx.input_tensors]
            recomputed = ctx.run_function(*views)

        grads = torch.autograd.grad(
            recomputed,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del recomputed
        # `run_function` and `length` are not differentiable.
        return (None, None) + grads
def conv_nd(dims, *args, **kwargs):
    """
    Build an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` depending on
    ``dims``; all remaining arguments are forwarded to the constructor.

    :param dims: 1, 2 or 3.
    :raises ValueError: for any other value of ``dims``.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class AlphaBlender(nn.Module):
    """
    Blend a spatial and a temporal feature tensor as
    ``alpha * x_spatial + (1 - alpha) * x_temporal``.

    How ``alpha`` is obtained depends on ``merge_strategy``:

    * ``"fixed"``: the constant ``alpha`` passed in, stored as a buffer.
    * ``"learned"``: sigmoid of a learnable scalar initialised with ``alpha``.
    * ``"learned_with_images"``: like ``"learned"``, except that positions
      flagged by ``image_only_indicator`` use an alpha of 1.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        """
        :param alpha: mix factor (used as-is for ``"fixed"``; passed through a
            sigmoid for the learned strategies).
        :param merge_strategy: one of ``AlphaBlender.strategies``.
        :param rearrange_pattern: einops pattern applied to the alpha tensor
            in the ``"learned_with_images"`` strategy.
        """
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        initial = torch.Tensor([alpha])
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", initial)
        elif self.merge_strategy in ("learned", "learned_with_images"):
            self.register_parameter("mix_factor", torch.nn.Parameter(initial))
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        """
        Compute the blending weight for the configured strategy.

        :param image_only_indicator: only read by ``"learned_with_images"``,
            where it must not be None.
        :raises NotImplementedError: if ``merge_strategy`` is unknown.
        """
        strategy = self.merge_strategy
        if strategy == "fixed":
            return self.mix_factor
        if strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        if strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            image_alpha = torch.ones(1, 1, device=image_only_indicator.device)
            learned_alpha = rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1")
            # Flagged positions take alpha == 1, the rest the learned alpha.
            per_position = torch.where(
                image_only_indicator.bool(), image_alpha, learned_alpha
            )
            return rearrange(per_position, self.rearrange_pattern)
        raise NotImplementedError

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the alpha-weighted mix, computed in ``x_spatial``'s dtype."""
        weight = self.get_alpha(image_only_indicator)
        target_dtype = x_spatial.dtype
        spatial_part = weight.to(target_dtype) * x_spatial
        temporal_part = (1.0 - weight).to(target_dtype) * x_temporal
        return spatial_part + temporal_part
class VideoUNet(nn.Module):
    """
    U-Net for video denoising assembled from ``VideoResBlock`` and
    ``SpatialVideoTransformer`` layers.

    The encoder (``input_blocks``) stores every intermediate activation; the
    decoder (``output_blocks``) pops them in reverse order and concatenates
    them along the channel axis before each block (skip connections).
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        num_res_blocks: int,
        attention_resolutions: List[int],
        dropout: float = 0.0,
        channel_mult: List[int] = (1, 2, 4, 8),
        conv_resample: bool = True,
        dims: int = 2,
        num_classes: Optional[Union[int, str]] = None,
        use_checkpoint: bool = False,
        num_heads: int = -1,
        num_head_channels: int = -1,
        num_heads_upsample: int = -1,
        use_scale_shift_norm: bool = False,
        resblock_updown: bool = False,
        transformer_depth: Union[List[int], int] = 1,
        transformer_depth_middle: Optional[int] = None,
        context_dim: Optional[int] = None,
        time_downup: bool = False,
        time_context_dim: Optional[int] = None,
        extra_ff_mix_layer: bool = False,
        use_spatial_context: bool = False,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        spatial_transformer_attn_type: str = "softmax",
        video_kernel_size: Union[int, List[int]] = 3,
        use_linear_in_transformer: bool = False,
        adm_in_channels: Optional[int] = None,
        disable_temporal_crossattention: bool = False,
        max_ddpm_temb_period: int = 10000,
    ):
        """
        :param in_channels: channels of the input tensor.
        :param model_channels: base channel count; level ``i`` uses
            ``channel_mult[i] * model_channels`` channels.
        :param out_channels: channels of the output tensor.
        :param num_res_blocks: residual blocks per resolution level.
        :param attention_resolutions: collection of downsample factors (1, 2,
            4, ...) at which an attention layer follows each residual block.
        :param channel_mult: per-level channel multipliers.
        :param num_classes: None (unconditional), an int (embedding table), or
            one of ``"continuous"``, ``"timestep"``, ``"sequential"``.
        :param num_heads: attention heads; exactly one of ``num_heads`` /
            ``num_head_channels`` needs to differ from -1.
        :param transformer_depth: transformer depth, one value per level or a
            single int that is broadcast to all levels.
        :param transformer_depth_middle: depth of the middle transformer;
            falls back to ``transformer_depth[-1]`` via ``default``.
        :param context_dim: conditioning width; must not be None.
        :param adm_in_channels: input width of the ``"sequential"`` label
            embedding; required for that mode.
        :raises ValueError: if ``num_classes`` is an unsupported value.

        The remaining arguments are forwarded unchanged to ``VideoResBlock``,
        ``SpatialVideoTransformer``, ``Downsample`` and ``Upsample``.
        """
        super().__init__()
        assert context_dim is not None

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1

        if num_head_channels == -1:
            assert num_heads != -1

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = nn.Sequential(
                    Timestep(model_channels),
                    nn.Sequential(
                        linear(model_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    ),
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel count of every encoder output, consumed (popped) by the
        # decoder to size its skip-connection inputs.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample factor relative to the input

        def get_attention_layer(
            ch,
            num_heads,
            dim_head,
            depth=1,
            context_dim=None,
            use_checkpoint=False,
            disabled_sa=False,
        ):
            # Thin factory binding the constructor-level transformer options.
            return SpatialVideoTransformer(
                ch,
                num_heads,
                dim_head,
                depth=depth,
                context_dim=context_dim,
                time_context_dim=time_context_dim,
                dropout=dropout,
                ff_in=extra_ff_mix_layer,
                use_spatial_context=use_spatial_context,
                merge_strategy=merge_strategy,
                merge_factor=merge_factor,
                checkpoint=use_checkpoint,
                use_linear=use_linear_in_transformer,
                attn_mode=spatial_transformer_attn_type,
                disable_self_attn=disabled_sa,
                disable_temporal_crossattention=disable_temporal_crossattention,
                max_time_embed_period=max_ddpm_temb_period,
            )

        def get_resblock(
            merge_factor,
            merge_strategy,
            video_kernel_size,
            ch,
            time_embed_dim,
            dropout,
            out_ch,
            dims,
            use_checkpoint,
            use_scale_shift_norm,
            down=False,
            up=False,
        ):
            # Thin factory mapping the short local names onto VideoResBlock.
            return VideoResBlock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                channels=ch,
                emb_channels=time_embed_dim,
                dropout=dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                down=down,
                up=up,
            )

        # ----- encoder -----
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_ch=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        get_attention_layer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth[level],
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                            disabled_sa=False,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Every level except the last ends with a downsampling stage.
                ds *= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_ch=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            third_down=time_downup,
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)

                self._feature_size += ch

        # ----- bottleneck -----
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels

        self.middle_block = TimestepEmbedSequential(
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                time_embed_dim=time_embed_dim,
                out_ch=None,
                dropout=dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            get_attention_layer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth_middle,
                context_dim=context_dim,
                use_checkpoint=use_checkpoint,
            ),
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                out_ch=None,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ----- decoder -----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Channels of the skip connection concatenated in forward().
                ich = input_block_chans.pop()
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch + ich,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_ch=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        get_attention_layer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth[level],
                            context_dim=context_dim,
                            use_checkpoint=use_checkpoint,
                            disabled_sa=False,
                        )
                    )
                if level and i == num_res_blocks:
                    # Last block of every level but the outermost upsamples.
                    out_ch = ch
                    ds //= 2
                    layers.append(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_ch=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            third_up=time_downup,
                        )
                    )

                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # The decoder ends with `ch == channel_mult[0] * model_channels`
        # channels. The conv must take `ch` inputs (like the norm before it);
        # hard-coding `model_channels` only worked when channel_mult[0] == 1,
        # in which case both values coincide.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )

    def forward(
        self,
        x: th.Tensor,
        timesteps: th.Tensor,
        context: Optional[th.Tensor] = None,
        y: Optional[th.Tensor] = None,
        time_context: Optional[th.Tensor] = None,
        num_video_frames: Optional[int] = None,
        image_only_indicator: Optional[th.Tensor] = None,
    ):
        """
        Run the U-Net.

        :param x: input tensor. NOTE(review): presumably laid out as
            ``(batch * frames, channels, height, width)``, matching the
            ``"(b t) c h w"`` pattern used by ``VideoResBlock`` -- confirm.
        :param timesteps: 1-D tensor of diffusion timesteps, one per row of x.
        :param context: conditioning passed through to every block.
        :param y: label/vector conditioning; required iff ``num_classes`` was
            set, and its first dimension must equal ``x.shape[0]``.
        :param time_context: passed through to every block.
        :param num_video_frames: passed through to every block.
        :param image_only_indicator: passed through to every block.
        :return: output of the final normalisation/conv head.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module in self.input_blocks:
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )
            hs.append(h)
        h = self.middle_block(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
        for module in self.output_blocks:
            # Skip connection: concatenate the matching encoder activation
            # along the channel axis.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(
                h,
                emb,
                context=context,
                image_only_indicator=image_only_indicator,
                time_context=time_context,
                num_video_frames=num_video_frames,
            )
        h = h.type(x.dtype)
        return self.out(h)
float = 0.5, apply_sigmoid_to_merge: bool = True, ff_in: bool = False, attn_mode: str = "softmax", disable_temporal_crossattention: bool = False, ): super().__init__( in_channels, n_heads, d_head, use_checkpoint=use_checkpoint, use_new_attention_order=use_new_attention_order, ) inner_dim = n_heads * d_head self.time_mix_blocks = nn.ModuleList( [ BasicTransformerTimeMixBlock( inner_dim, n_heads, d_head, dropout=dropout, checkpoint=use_checkpoint, ff_in=ff_in, attn_mode=attn_mode, disable_temporal_crossattention=disable_temporal_crossattention, ) ] ) self.in_channels = in_channels time_embed_dim = self.in_channels * 4 self.time_mix_time_embed = nn.Sequential( linear(self.in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, self.in_channels), ) self.use_spatial_context = use_spatial_context if merge_strategy == "fixed": self.register_buffer("mix_factor", th.Tensor([merge_factor])) elif merge_strategy == "learned" or merge_strategy == "learned_with_images": self.register_parameter( "mix_factor", th.nn.Parameter(th.Tensor([merge_factor])) ) elif merge_strategy == "fixed_with_images": self.mix_factor = None else: raise ValueError(f"unknown merge strategy {merge_strategy}") self.get_alpha_fn = functools.partial( get_alpha, merge_strategy, self.mix_factor, apply_sigmoid=apply_sigmoid_to_merge, ) def forward( self, x: th.Tensor, context: Optional[th.Tensor] = None, # cam: Optional[th.Tensor] = None, time_context: Optional[th.Tensor] = None, timesteps: Optional[int] = None, image_only_indicator: Optional[th.Tensor] = None, conv_view: Optional[th.Tensor] = None, conv_motion: Optional[th.Tensor] = None, ): if time_context is not None: raise NotImplementedError _, _, h, w = x.shape if exists(context): context = rearrange(context, "b t ... -> (b t) ...") if self.use_spatial_context: time_context = repeat(context[:, 0], "b ... 
class PostHocResBlockWithTime(ResBlock):
    """
    Spatial ``ResBlock`` followed by a 3-D ``ResBlock`` over the frame axis,
    with both results blended by a mix factor.

    With ``time_mix_legacy=True`` the blend weight comes from the module-level
    ``get_alpha`` helper; otherwise an ``AlphaBlender`` is used.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        time_kernel_size: Union[int, List[int]] = 3,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        apply_sigmoid_to_merge: bool = True,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        time_mix_legacy: bool = True,
        replicate_bug: bool = False,
    ):
        """
        :param channels: input channels of the spatial block.
        :param emb_channels: width of the timestep embedding.
        :param dropout: dropout rate used by both residual blocks.
        :param time_kernel_size: kernel size of the temporal (3-D) block.
        :param merge_strategy: ``"fixed"``, ``"learned"``,
            ``"learned_with_images"`` or (legacy path only)
            ``"fixed_with_images"``.
        :param merge_factor: initial/fixed value of the mix factor.
        :param apply_sigmoid_to_merge: legacy path only; forwarded to
            ``get_alpha`` as ``apply_sigmoid``.
        :param time_mix_legacy: select the ``get_alpha`` based blending.
        :param replicate_bug: accepted for configuration compatibility but
            unused (the only code that consulted it was unreachable).
        :raises ValueError: legacy path, unknown ``merge_strategy``.

        The remaining arguments are forwarded to ``ResBlock``.
        """
        super().__init__(
            channels,
            emb_channels,
            dropout,
            out_channels=out_channels,
            use_conv=use_conv,
            use_scale_shift_norm=use_scale_shift_norm,
            dims=dims,
            use_checkpoint=use_checkpoint,
            up=up,
            down=down,
        )

        self.time_mix_blocks = ResBlock(
            default(out_channels, channels),
            emb_channels,
            dropout=dropout,
            dims=3,
            out_channels=default(out_channels, channels),
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=time_kernel_size,
            use_checkpoint=use_checkpoint,
            exchange_temb_dims=True,
        )

        self.time_mix_legacy = time_mix_legacy
        if self.time_mix_legacy:
            if merge_strategy == "fixed":
                self.register_buffer("mix_factor", th.Tensor([merge_factor]))
            elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
                self.register_parameter(
                    "mix_factor", th.nn.Parameter(th.Tensor([merge_factor]))
                )
            elif merge_strategy == "fixed_with_images":
                self.mix_factor = None
            else:
                raise ValueError(f"unknown merge strategy {merge_strategy}")

            # `partial` is what this file imports (`from functools import
            # partial`); the bare `functools` module name is not imported here.
            self.get_alpha_fn = partial(
                get_alpha,
                merge_strategy,
                self.mix_factor,
                apply_sigmoid=apply_sigmoid_to_merge,
            )
        else:
            self.time_mixer = AlphaBlender(
                alpha=merge_factor,
                merge_strategy=merge_strategy,
                rearrange_pattern="b t -> b 1 t 1 1",
            )

    def forward(
        self,
        x: th.Tensor,
        emb: th.Tensor,
        num_video_frames: int,
        image_only_indicator: Optional[th.Tensor] = None,
        cond_view: Optional[th.Tensor] = None,
        cond_motion: Optional[th.Tensor] = None,
    ) -> th.Tensor:
        """
        :param x: tensor laid out as ``(b t) c h w``.
        :param emb: embedding whose leading axis is ``(b t)``.
        :param num_video_frames: ``t`` in the patterns above.
        :param image_only_indicator: optional indicator; it is zeroed before
            use. May be None for strategies that do not read it.
        :param cond_view: unused by this block.
        :param cond_motion: unused by this block.
        :return: blended tensor, again laid out as ``(b t) c h w``.
        """
        x = super().forward(x, emb)

        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)

        x = self.time_mix_blocks(
            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
        )

        # Multiplying by 0.0 clears every flag, so indicator-aware strategies
        # never take their "image only" branch. Guard the documented default
        # (None), which previously raised a TypeError on `None * 0.0`.
        if image_only_indicator is None:
            cleared_indicator = None
        else:
            cleared_indicator = image_only_indicator * 0.0

        if self.time_mix_legacy:
            alpha = self.get_alpha_fn(image_only_indicator=cleared_indicator)
            # Legacy convention: alpha weights the temporal branch `x`.
            x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
        else:
            x = self.time_mixer(
                x_spatial=x_mix,
                x_temporal=x,
                image_only_indicator=cleared_indicator,
            )
        x = rearrange(x, "b c t h w -> (b t) c h w")
        return x
assert context_dim is not None if context_dim is not None: assert use_spatial_transformer if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert num_head_channels != -1 if num_head_channels == -1: assert num_heads != -1 self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels if isinstance(transformer_depth, int): transformer_depth = len(channel_mult) * [transformer_depth] transformer_depth_middle = default( transformer_depth_middle, transformer_depth[-1] ) self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.use_temporal_resblocks = use_temporal_resblock time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "timestep": self.label_emb = nn.Sequential( Timestep(model_channels), nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ), ) elif self.num_classes == "sequential": assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( linear(adm_in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) ) else: raise ValueError() self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels 
input_block_chans = [model_channels] ch = model_channels ds = 1 def get_attention_layer( ch, num_heads, dim_head, depth=1, context_dim=None, use_checkpoint=False, disabled_sa=False, ): if not use_spatial_transformer: return PostHocAttentionBlockWithTimeMixing( ch, num_heads, dim_head, use_checkpoint=use_checkpoint, use_new_attention_order=use_new_attention_order, dropout=dropout, ff_in=extra_ff_mix_layer, use_spatial_context=use_spatial_context, merge_strategy=time_block_merge_strategy, merge_factor=time_block_merge_factor, attn_mode=spatial_transformer_attn_type, disable_temporal_crossattention=disable_temporal_crossattention, ) elif use_motion_attention: return PostHocSpatialTransformerWithTimeMixingAndMotion( ch, num_heads, dim_head, depth=depth, context_dim=context_dim, time_context_dim=time_context_dim, motion_context_dim=motion_context_dim, dropout=dropout, ff_in=extra_ff_mix_layer, use_spatial_context=use_spatial_context, use_camera_emb=use_camera_emb, use_3d_attention=use_3d_attention, separate_motion_merge_factor=separate_motion_merge_factor, adm_in_channels=adm_in_channels, merge_strategy=time_block_merge_strategy, merge_factor=time_block_merge_factor, merge_factor_motion=motion_block_merge_factor, checkpoint=use_checkpoint, use_linear=use_linear_in_transformer, attn_mode=spatial_transformer_attn_type, disable_self_attn=disabled_sa, disable_temporal_crossattention=disable_temporal_crossattention, time_mix_legacy=time_mix_legacy, max_time_embed_period=max_ddpm_temb_period, ) else: return PostHocSpatialTransformerWithTimeMixing( ch, num_heads, dim_head, depth=depth, context_dim=context_dim, time_context_dim=time_context_dim, dropout=dropout, ff_in=extra_ff_mix_layer, use_spatial_context=use_spatial_context, merge_strategy=time_block_merge_strategy, merge_factor=time_block_merge_factor, checkpoint=use_checkpoint, use_linear=use_linear_in_transformer, attn_mode=spatial_transformer_attn_type, disable_self_attn=disabled_sa, 
disable_temporal_crossattention=disable_temporal_crossattention, time_mix_legacy=time_mix_legacy, max_time_embed_period=max_ddpm_temb_period, ) def get_resblock( time_block_merge_factor, time_block_merge_strategy, time_kernel_size, ch, time_embed_dim, dropout, out_ch, dims, use_checkpoint, use_scale_shift_norm, down=False, up=False, ): if self.use_temporal_resblocks: return PostHocResBlockWithTime( merge_factor=time_block_merge_factor, merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, channels=ch, emb_channels=time_embed_dim, dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=down, up=up, time_mix_legacy=time_mix_legacy, replicate_bug=replicate_time_mix_bug, ) else: return ResBlock( channels=ch, emb_channels=time_embed_dim, dropout=dropout, out_channels=out_ch, use_checkpoint=use_checkpoint, dims=dims, use_scale_shift_norm=use_scale_shift_norm, down=down, up=up, ) for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ get_resblock( time_block_merge_factor=time_block_merge_factor, time_block_merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: dim_head = ( ch // num_heads if use_spatial_transformer else num_head_channels ) layers.append( get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, use_checkpoint=use_checkpoint, disabled_sa=False, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: ds *= 2 
out_ch = ch self.input_blocks.append( TimestepEmbedSequential( get_resblock( time_block_merge_factor=time_block_merge_factor, time_block_merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch, third_down=time_downup, ) ) ) ch = out_ch input_block_chans.append(ch) self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( get_resblock( time_block_merge_factor=time_block_merge_factor, time_block_merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, ch=ch, time_embed_dim=time_embed_dim, out_ch=None, dropout=dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, use_checkpoint=use_checkpoint, ), get_resblock( time_block_merge_factor=time_block_merge_factor, time_block_merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, ch=ch, out_ch=None, time_embed_dim=time_embed_dim, dropout=dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ get_resblock( time_block_merge_factor=time_block_merge_factor, time_block_merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, ch=ch + ich, time_embed_dim=time_embed_dim, 
dropout=dropout, out_ch=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: dim_head = ( ch // num_heads if use_spatial_transformer else num_head_channels ) layers.append( get_attention_layer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, use_checkpoint=use_checkpoint, disabled_sa=False, ) ) if level and i == num_res_blocks: out_ch = ch ds //= 2 layers.append( get_resblock( time_block_merge_factor=time_block_merge_factor, time_block_merge_strategy=time_block_merge_strategy, time_kernel_size=time_kernel_size, ch=ch, time_embed_dim=time_embed_dim, dropout=dropout, out_ch=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, ) if resblock_updown else Upsample( ch, conv_resample, dims=dims, out_channels=out_ch, third_up=time_downup, ) ) self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) def forward( self, x: th.Tensor, timesteps: th.Tensor, context: Optional[th.Tensor] = None, y: Optional[th.Tensor] = None, cam: Optional[th.Tensor] = None, time_context: Optional[th.Tensor] = None, num_video_frames: Optional[int] = None, image_only_indicator: Optional[th.Tensor] = None, cond_view: Optional[th.Tensor] = None, cond_motion: Optional[th.Tensor] = None, time_step: Optional[int] = None, ): assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # 21 x 320 emb = self.time_embed(t_emb) # 21 x 1280 time = 
class OpenAIWrapper(IdentityWrapper):
    """Adapter that unpacks a conditioning dict into UNet keyword arguments."""

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        """Run the wrapped diffusion model on ``x`` with conditioning ``c``.

        Args:
            x: Input tensor; ``c["concat"]`` (if present) is concatenated to
                it along dim 1.
            t: Timesteps, forwarded as ``timesteps``.
            c: Conditioning dict. ``"crossattn"`` is forwarded as ``context``
                and ``"vector"`` as ``y``. When ``"cond_view"`` is present,
                ``cond_view`` and ``cond_motion`` are forwarded as well.
            **kwargs: Passed through to the wrapped model unchanged.

        Returns:
            Whatever the wrapped diffusion model returns.
        """
        # An empty tensor stands in when there is no "concat" conditioning.
        concat_cond = c.get("concat", torch.Tensor([]).type_as(x))
        x = torch.cat((x, concat_cond), dim=1)

        # View/motion conditioning is only passed on when "cond_view" exists.
        view_kwargs = {}
        if "cond_view" in c:
            view_kwargs = {
                "cond_view": c.get("cond_view", None),
                "cond_motion": c.get("cond_motion", None),
            }

        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            **view_kwargs,
            **kwargs,
        )
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    A shadow copy of every parameter with ``requires_grad=True`` is kept as a
    buffer on this module. Calling the module with the model moves each
    shadow value towards the current parameter value.

    Args:
        model: Module whose trainable parameters are tracked.
        decay: EMA decay, must lie in ``[0, 1]``.
        use_num_upates: When true, an update counter is kept and the
            effective decay is ``min(decay, (1 + n) / (10 + n))``; otherwise
            ``decay`` is used as is. (The parameter name keeps its original
            spelling for compatibility with existing callers.)

    Raises:
        ValueError: If ``decay`` is outside ``[0, 1]``.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        # Maps model parameter names to the names of their shadow buffers.
        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        # A counter of -1 disables the warm-up schedule in forward().
        self.register_buffer(
            "num_updates",
            torch.tensor(0 if use_num_upates else -1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name[name] = s_name
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        """Reset the update counter to zero (re-enabling the warm-up)."""
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        """Update the shadow buffers in place from ``model``'s parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            shadow_params = dict(self.named_buffers())

            for key, param in model.named_parameters():
                if param.requires_grad:
                    shadow = shadow_params[self.m_name2s_name[key]]
                    # Cast the parameter to the buffer's dtype/device rather
                    # than the other way round: converting the buffer would
                    # produce a copy, and the in-place update below would
                    # then never reach the registered buffer.
                    shadow.sub_(one_minus_decay * (shadow - param.to(shadow)))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite ``model``'s trainable parameters with the shadow values."""
        shadow_params = dict(self.named_buffers())
        for key, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
class GeneralConditioner(nn.Module):
    """Runs a list of embedders over a batch and merges their outputs.

    Each embedder output is filed under an output key and outputs sharing a
    key are concatenated. The key is the embedder's ``input_key`` when that
    is ``"cond_view"`` or ``"cond_motion"``; otherwise it is derived from the
    embedding's number of dimensions via ``OUTPUT_DIM2KEYS``.

    Args:
        emb_models: Embedder configs. Each must provide either ``input_key``
            or ``input_keys`` and may provide ``is_trainable``, ``ucg_rate``
            and ``legacy_ucg_value``.
    """

    # Embedding rank -> output key.
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"}  # , 5: "concat"}
    # Output key -> dimension along which same-key embeddings are concatenated.
    KEY2CATDIM = {
        "vector": 1,
        "crossattn": 2,
        "concat": 1,
        "cond_view": 1,
        "cond_motion": 1,
    }

    def __init__(self, emb_models: Union[List, ListConfig]):
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Freeze the embedder and pin it to eval mode.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )

            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )

            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()

            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        """Replace batch entries by the legacy unconditional value at random.

        Each entry of ``batch[embedder.input_key]`` is overwritten in place
        with ``embedder.legacy_ucg_val`` with probability
        ``embedder.ucg_rate``. The (mutated) batch is returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None
    ) -> Dict:
        """Embed ``batch`` with every embedder and merge the results.

        Args:
            batch: Maps input keys to the values the embedders consume.
            force_zero_embeddings: Input keys whose embeddings are replaced
                by zeros.

        Returns:
            Dict from output key to the (concatenated) embedding tensor.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Only trainable embedders need gradients.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                if embedder.input_key in ["cond_view", "cond_motion"]:
                    out_key = embedder.input_key
                else:
                    out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample Bernoulli mask: zero whole embeddings with
                    # probability ucg_rate.
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if (
                    hasattr(embedder, "input_key")
                    and embedder.input_key in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat(
                        (output[out_key], emb), self.KEY2CATDIM[out_key]
                    )
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_cond_zero_embeddings: Optional[List[str]] = None,
    ):
        """Compute conditional and unconditional conditioning without dropout.

        Every embedder's ``ucg_rate`` is set to 0 for the duration of the two
        forward passes and restored afterwards, also when a pass raises.

        Args:
            batch_c: Batch used for the conditional pass.
            batch_uc: Batch for the unconditional pass; ``batch_c`` is reused
                when this is ``None``.
            force_uc_zero_embeddings: Input keys zeroed in the unconditional
                pass.
            force_cond_zero_embeddings: Input keys zeroed in the conditional
                pass.

        Returns:
            Tuple ``(c, uc)`` of conditioning dicts.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = list()
        for embedder in self.embedders:
            ucg_rates.append(embedder.ucg_rate)
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(
                batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings
            )
        finally:
            # Restore the rates even on failure so a caught exception does
            # not leave conditioning dropout silently disabled.
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class ClassEmbedder(AbstractEmbModel):
    """Embeds integer class labels with a learned ``nn.Embedding`` table.

    Args:
        embed_dim: Size of each embedding vector.
        n_classes: Number of rows in the embedding table.
        add_sequence_dim: When true, a singleton dimension is inserted at
            position 1 of the output.
    """

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def forward(self, c):
        """Return the embedding of the class indices ``c``."""
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return a batch of ``bs`` labels set to the last class index.

        The result is a one-entry dict keyed by this embedder's ``key``
        attribute if one has been set, otherwise by its ``input_key``.
        """
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.full((bs,), uc_class, dtype=torch.long, device=device)
        # `key` is not defined by this class or its base, so reading it
        # unconditionally raised AttributeError. Honour it when a caller has
        # set it, otherwise fall back to the embedder's `input_key`.
        key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        return {key: uc}
class FrozenByT5Embedder(AbstractEmbModel):
    """
    Text embedder backed by a pretrained ByT5 encoder (character-aware).
    """

    def __init__(
        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
    ):
        super().__init__()
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and disable gradients for all weights."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        # Pad/truncate to max_length and return PyTorch tensors.
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoding = self.tokenizer(text, **tokenizer_options)
        input_ids = encoding["input_ids"].to(self.device)
        # The encoder runs with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            encoded = self.transformer(input_ids=input_ids)
        return encoded.last_hidden_state

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Text embedder built on the OpenCLIP text transformer.
    """

    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the text tower is used here; drop the vision tower.
        del clip_model.visual
        self.model = clip_model

        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]
        self.legacy = legacy

    def freeze(self):
        """Switch the model to eval mode and disable gradients."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, text):
        """Tokenize and encode ``text``.

        Returns the ``(layer output, pooled output)`` pair when pooled output
        was requested (non-legacy only). In legacy mode the result of
        ``encode_with_transformer`` is returned as is; otherwise only the
        selected layer's output is returned.
        """
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        if self.legacy:
            return z
        return z[self.layer]

    def encode_with_transformer(self, text):
        """Embed token ids and run them through the text transformer."""
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        if self.legacy:
            # Legacy mode returns just the final-layer-normed selected layer.
            return self.model.ln_final(x[self.layer])
        # Non-legacy mode keeps the dict and adds the pooled projection.
        x["pooled"] = self.pool(self.model.ln_final(x["last"]), text)
        return x

    def pool(self, x, text):
        """Project the features at each sequence's EOT position."""
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(x.shape[0])
        eot_index = text.argmax(dim=-1)
        return x[batch_index, eot_index] @ self.model.text_projection

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks, recording the last block's input and output."""
        outputs = {}
        resblocks = self.model.transformer.resblocks
        for idx, block in enumerate(resblocks):
            if idx == len(resblocks) - 1:
                outputs["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            use_checkpoint = (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            )
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        outputs["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return outputs

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer encoder for images

    The text transformer of the loaded OpenCLIP model is deleted; only
    ``model.visual`` is used. Depending on the constructor flags, ``forward``
    returns a single embedding or a pair (see ``forward``).
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None,
    ):
        super().__init__()
        # The model is created on `init_device`, falling back to CPU.
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        del model.transformer
        self.model = model
        self.max_crops = num_image_crops
        # Padding mode is implied by using image crops and takes precedence
        # over repeating.
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        # Per-channel normalization constants used by `preprocess`; kept out
        # of the state dict (persistent=False).
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        # Keep the visual tower's token-output flag in sync with ours;
        # `encode_with_vision_transformer` asserts that they agree.
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224 and apply the CLIP mean/std normalization."""
        # normalize to [0,1]
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # NOTE(review): this shift presumably expects inputs in [-1, 1] -- confirm with callers.
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Put the model in eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Encode ``image`` and shape the result according to the flags.

        Args:
            image: Image tensor passed to ``encode_with_vision_transformer``.
            no_dropout: When true, the random zeroing of embeddings in this
                method is skipped.

        Returns:
            ``(tokens, z)`` if ``output_tokens`` is set;
            ``(z repeated to max_length, z)`` if ``repeat_to_max_len`` is set;
            ``(z_pad, z_pad[:, 0, ...])`` if ``pad_to_max_len`` is set;
            otherwise ``z`` alone.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        # Per-sample dropout of the embedding. Skipped when image crops are
        # used, since the encoder already applies dropout in that case.
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                # The token mask is drawn independently of the mask for z.
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate)
                            * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens,
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            # Zero-pad along dim 1 up to max_length.
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess ``img`` and run the visual tower.

        A 5-D input is treated as ``b n c h w`` with ``n == max_crops`` and
        flattened to a batch of images first. Returns ``(x, tokens)`` when
        ``output_tokens`` is set, otherwise ``x``.
        """
        # if self.max_crops > 0:
        #    img = self.preprocess_by_cropping(img)
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            # Regroup the flattened crops per sample.
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        """Alias for calling the module; the argument is forwarded to ``forward`` as the image."""
        return self(text)
class SpatialRescaler(nn.Module):
    """Rescale inputs by repeated interpolation, optionally remapping channels.

    Args:
        n_stages: Number of times the interpolation is applied (>= 0).
        method: Interpolation mode passed to ``torch.nn.functional.interpolate``.
        multiplier: Scale factor used at every stage.
        in_channels: Input channels of the optional channel-mapping conv.
        out_channels: Output channels of the channel-mapping conv; giving a
            value enables the remapping.
        bias: Whether the channel-mapping conv has a bias.
        wrap_video: When true, 5-D inputs laid out as ``b c t h w`` are
            folded to ``(b t) c h w`` for processing and unfolded afterwards.
        kernel_size: Kernel size of the channel-mapping conv.
        remap_output: Enables the remapping even without ``out_channels``.
    """

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        """Rescale ``x`` and, if enabled, remap its channels."""
        # Only unfold what was actually folded. Previously a non-5D input
        # with wrap_video=True reached the unfold step with the batch/time
        # sizes undefined and raised NameError.
        fold_video = self.wrap_video and x.ndim == 5
        if fold_video:
            B, T = x.shape[0], x.shape[2]
            x = rearrange(x, "b c t h w -> b t c h w")
            x = rearrange(x, "b t c h w -> (b t) c h w")

        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        # Remap channels while video input is still folded to 4-D: the
        # mapper is a Conv2d and cannot consume the unfolded 5-D tensor.
        if self.remap_output:
            x = self.channel_mapper(x)

        if fold_video:
            # The channel count is not pinned here because the mapper may
            # have changed it.
            x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T)
            x = rearrange(x, "b t c h w -> b c t h w")
        return x

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)
linear_end=2e-2, cosine_s=8e-3, ): betas = make_beta_schedule( beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert ( alphas_cumprod.shape[0] == self.num_timesteps ), "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer("betas", to_torch(betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) self.register_buffer( "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) ) self.register_buffer( "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) ) self.register_buffer( "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) ) self.register_buffer( "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def forward(self, x): z = self.model.encode(x) if isinstance(z, DiagonalGaussianDistribution): z = z.sample() z = z * self.scale_factor noise_level = torch.randint( 0, self.max_noise_level, (x.shape[0],), device=x.device ).long() z = self.q_sample(z, noise_level) if self.out_size is not None: z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") return z, noise_level def decode(self, z): z = z / self.scale_factor return 
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """Embed every scalar of the input with ``Timestep`` and join the results.

    A 1-D input is treated as a single column.  For a 2-D input ``(b, d)``
    each of the ``b * d`` entries is embedded on its own and the ``d``
    embeddings of a row are concatenated along the feature axis.
    """

    def __init__(self, outdim):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim

    def forward(self, x):
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        batch, n_dims = x.shape[0], x.shape[1]
        # flatten so that the embedder sees one scalar per row
        flat = rearrange(x, "b d -> (b d)")
        per_scalar = self.timestep(flat)
        return rearrange(
            per_scalar, "(b d) d2 -> b (d d2)", b=batch, d=n_dims, d2=self.outdim
        )
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Embed conditioning frames with a configured embedder and replicate them.

    The wrapped embedder is built from ``open_clip_embedding_config``.  Its
    output is regrouped into ``n_cond_frames`` rows per sample and every
    sample is then repeated ``n_copies`` times along the batch axis.
    """

    def __init__(
        self,
        open_clip_embedding_config: Dict,
        n_cond_frames: int,
        n_copies: int,
    ):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        embedded = self.open_clip(vid)
        per_sample = rearrange(embedded, "(b t) d -> b t d", t=self.n_cond_frames)
        return repeat(per_sample, "b t d -> (b s) t d", s=self.n_copies)
class TimeMixSequential(nn.Sequential):
    """Sequential container that hands ``context`` and ``timesteps`` to every layer."""

    def forward(self, x, context=None, timesteps=None):
        hidden = x
        for module in self:
            # each layer gets the running activation plus the shared conditioning
            hidden = module(hidden, context, timesteps)
        return hidden
class PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):
    """Spatial transformer whose blocks are each followed by a time-mixing block.

    Every spatial block from the parent class is paired with a
    ``BasicTransformerTimeMixBlock``.  The temporal branch receives the
    spatial output plus a learned embedding of the frame index, and the two
    branches are blended either with a legacy ``get_alpha`` factor
    (``time_mix_legacy=True``) or with an ``AlphaBlender``.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        apply_sigmoid_to_merge: bool = True,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        time_mix_legacy: bool = True,
        max_time_embed_period: int = 10000,
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            # the temporal blocks attend to the spatial context itself
            time_context_dim = context_dim

        self.time_mix_blocks = nn.ModuleList(
            [
                BasicTransformerTimeMixBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    attn_mode=attn_mode,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_mix_blocks) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_mix_time_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, self.in_channels),
        )

        self.time_mix_legacy = time_mix_legacy
        if self.time_mix_legacy:
            if merge_strategy == "fixed":
                self.register_buffer("mix_factor", torch.Tensor([merge_factor]))
            elif merge_strategy == "learned" or merge_strategy == "learned_with_images":
                self.register_parameter(
                    "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor]))
                )
            elif merge_strategy == "fixed_with_images":
                self.mix_factor = None
            else:
                raise ValueError(f"unknown merge strategy {merge_strategy}")

            self.get_alpha_fn = partial(
                get_alpha,
                merge_strategy,
                self.mix_factor,
                apply_sigmoid=apply_sigmoid_to_merge,
                is_attn=True,
            )
        else:
            self.time_mixer = AlphaBlender(
                alpha=merge_factor, merge_strategy=merge_strategy
            )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        # cam: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        cond_view: Optional[torch.Tensor] = None,
        cond_motion: Optional[torch.Tensor] = None,
        time_step: Optional[int] = None,
        name: Optional[str] = None,
    ) -> torch.Tensor:
        """Apply alternating spatial / temporal attention and add the input back.

        ``cond_view`` and ``cond_motion`` are accepted but not used here.
        ``name``, when given, is suffixed with the block index and forwarded
        to each spatial block; when omitted, ``None`` is forwarded.
        """
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            time_context = context
            # keep one context entry per clip, then replicate it per pixel
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        if self.time_mix_legacy:
            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)

        # embedding of the frame index, shared by all blocks
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(
            num_frames,
            self.in_channels,
            repeat_only=False,
            max_period=self.max_time_embed_period,
        )
        emb = self.time_mix_time_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_mix_blocks)
        ):
            # `name` defaults to None; concatenating it unconditionally used to
            # raise TypeError for every caller that did not pass a name.
            block_name = None if name is None else name + "_" + str(it_)

            # spatial attention
            x = block(
                x,
                context=spatial_context,
                time_step=time_step,
                name=block_name,
            )

            x_mix = x
            x_mix = x_mix + emb

            # temporal attention
            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
            if self.time_mix_legacy:
                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
            else:
                x = self.time_mixer(
                    x_spatial=x,
                    x_temporal=x_mix,
                    image_only_indicator=image_only_indicator,
                )

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
nn.ModuleList( [ BasicTransformerTimeMixBlock( inner_dim, n_time_mix_heads, time_mix_d_head, dropout=dropout, context_dim=motion_context_dim, timesteps=timesteps, checkpoint=checkpoint, ff_in=ff_in, inner_dim=time_mix_inner_dim, attn_mode=attn_mode, disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, ) for _ in range(self.depth) ] ) assert len(self.time_mix_blocks) == len(self.transformer_blocks) self.use_spatial_context = use_spatial_context self.in_channels = in_channels time_embed_dim = self.in_channels * 4 time_embed_channels = adm_in_channels if self.use_camera_emb else self.in_channels # Camera view embedding self.time_mix_time_embed = nn.Sequential( linear(time_embed_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, self.in_channels), ) # Motion time embedding self.time_mix_motion_embed = nn.Sequential( linear(self.in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, self.in_channels), ) self.time_mix_legacy = time_mix_legacy if self.time_mix_legacy: if merge_strategy == "fixed": self.register_buffer("mix_factor", torch.Tensor([merge_factor])) elif merge_strategy == "learned" or merge_strategy == "learned_with_images": self.register_parameter( "mix_factor", torch.nn.Parameter(torch.Tensor([merge_factor])) ) elif merge_strategy == "fixed_with_images": self.mix_factor = None else: raise ValueError(f"unknown merge strategy {merge_strategy}") self.get_alpha_fn = partial( get_alpha, merge_strategy, self.mix_factor, apply_sigmoid=apply_sigmoid_to_merge, is_attn=True, ) else: self.time_mixer = AlphaBlender( alpha=merge_factor, merge_strategy=merge_strategy ) if self.separate_motion_merge_factor: self.time_mixer_motion = AlphaBlender( alpha=merge_factor_motion, merge_strategy=merge_strategy ) def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, cam: Optional[torch.Tensor] = None, time_context: Optional[torch.Tensor] = None, timesteps: Optional[int] = None, 
image_only_indicator: Optional[torch.Tensor] = None, cond_view: Optional[torch.Tensor] = None, cond_motion: Optional[torch.Tensor] = None, time_step: Optional[int] = None, name: Optional[str] = None, ) -> torch.Tensor: # context: b t 1024 # cond_view: b*v 4 h w # cond_motion: b*t 4 h w # image_only_indicator: b t*v b, t, d1 = context.shape # CLIP v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE _, c, h, w = x.shape x_in = x spatial_context = None if exists(context): spatial_context = context cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode="bilinear") # b*v d h w spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1 camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1 motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2 x = self.norm(x) if not self.use_linear: x = self.proj_in(x) x = rearrange(x, "b c h w -> b (h w) c") if self.use_linear: x = self.proj_in(x) if self.time_mix_legacy: alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator) num_frames = torch.arange(t, device=x.device) num_frames = repeat(num_frames, "t -> b t", b=b) num_frames = rearrange(num_frames, "b t -> (b t)") t_emb = timestep_embedding( num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period, ) emb_time = self.time_mix_motion_embed(t_emb) emb_time = emb_time[:, None, :] # b*t 1 c if self.use_camera_emb: emb_view = self.time_mix_time_embed(cam.view(b,t,v,-1)[:,0].reshape(b*v,-1)) emb_view = emb_view[:, None, :] else: num_views = torch.arange(v, device=x.device) num_views = repeat(num_views, "t -> b t", b=b) num_views = rearrange(num_views, "b t -> (b t)") v_emb = timestep_embedding( num_views, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period, ) emb_view = self.time_mix_time_embed(v_emb) emb_view = emb_view[:, None, :] # b*v 1 c if self.use_3d_attention: emb_view = emb_view.repeat(1, h*w, 1).view(-1,1,c) # 
class TimeMixSequential(nn.Sequential):
    """``nn.Sequential`` variant for layers called as ``layer(x, context, timesteps)``."""

    def forward(self, x, context=None, timesteps=None):
        out = x
        for stage in self:
            # conditioning arguments are passed through unchanged to each stage
            out = stage(out, context, timesteps)
        return out
class VideoTransformerBlock(nn.Module):
    """Transformer block that attends along the time axis of a frame batch.

    The input ``(b*t, s, c)`` is regrouped to ``(b*s, t, c)`` so that the
    attention layers mix information across the ``t`` frames of each spatial
    position, and is regrouped back before returning.
    """

    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        timesteps=None,
        ff_in=False,
        inner_dim=None,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        switch_temporal_ca_to_sa=False,
    ):
        super().__init__()

        attn_cls = self.ATTENTION_MODES[attn_mode]

        # an explicit inner_dim always implies an input feed-forward
        self.ff_in = ff_in or inner_dim is not None
        inner_dim = dim if inner_dim is None else inner_dim

        assert int(n_heads * d_head) == inner_dim

        self.is_res = inner_dim == dim

        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim)
            self.ff_in = FeedForward(
                dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
            )

        self.timesteps = timesteps
        self.disable_self_attn = disable_self_attn

        # arguments shared by every attention layer of this block
        attn_kwargs = dict(
            query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )
        if self.disable_self_attn:
            # is a cross-attention
            self.attn1 = attn_cls(context_dim=context_dim, **attn_kwargs)
        else:
            # is a self-attention
            self.attn1 = attn_cls(**attn_kwargs)

        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)

        if not disable_temporal_crossattention:
            self.norm2 = nn.LayerNorm(inner_dim)
            if switch_temporal_ca_to_sa:
                # is a self-attention
                self.attn2 = attn_cls(**attn_kwargs)
            else:
                # is self-attn if context is none
                self.attn2 = attn_cls(context_dim=context_dim, **attn_kwargs)
        elif switch_temporal_ca_to_sa:
            raise ValueError
        else:
            self.attn2 = None

        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

        self.checkpoint = checkpoint
        if self.checkpoint:
            print(f"{self.__class__.__name__} is using checkpointing")

    def forward(
        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
    ) -> torch.Tensor:
        """Dispatch to ``_forward``, through ``checkpoint`` when enabled."""
        if not self.checkpoint:
            return self._forward(x, context, timesteps=timesteps)
        return checkpoint(self._forward, x, context, timesteps)

    def _forward(self, x, context=None, timesteps=None):
        # the frame count must come from the constructor or the call, and the
        # two have to agree when both are given
        assert self.timesteps or timesteps
        assert (
            not self.timesteps or not timesteps or self.timesteps == timesteps
        )
        timesteps = self.timesteps or timesteps
        B, S, C = x.shape
        x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        normed = self.norm1(x)
        if self.disable_self_attn:
            x = self.attn1(normed, context=context) + x
        else:
            x = self.attn1(normed) + x

        if self.attn2 is not None:
            normed = self.norm2(x)
            if self.switch_temporal_ca_to_sa:
                x = self.attn2(normed) + x
            else:
                x = self.attn2(normed, context=context) + x

        x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        x = rearrange(
            x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
        )
        return x

    def get_last_layer(self):
        """Return the weight of the last layer of the output feed-forward."""
        return self.ff.net[-1].weight
inner_dim, n_time_mix_heads, time_mix_d_head, dropout=dropout, context_dim=time_context_dim, timesteps=timesteps, checkpoint=checkpoint, ff_in=ff_in, inner_dim=time_mix_inner_dim, attn_mode=attn_mode, disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, ) for _ in range(self.depth) ] ) assert len(self.time_stack) == len(self.transformer_blocks) self.use_spatial_context = use_spatial_context self.in_channels = in_channels time_embed_dim = self.in_channels * 4 self.time_pos_embed = nn.Sequential( linear(self.in_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, self.in_channels), ) self.time_mixer = AlphaBlender( alpha=merge_factor, merge_strategy=merge_strategy ) def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, time_context: Optional[torch.Tensor] = None, timesteps: Optional[int] = None, image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: _, _, h, w = x.shape x_in = x spatial_context = None if exists(context): spatial_context = context if self.use_spatial_context: assert ( context.ndim == 3 ), f"n dims of spatial context should be 3 but are {context.ndim}" time_context = context time_context_first_timestep = time_context[::timesteps] time_context = repeat( time_context_first_timestep, "b ... -> (b n) ...", n=h * w ) elif time_context is not None and not self.use_spatial_context: time_context = repeat(time_context, "b ... 
def get_string_from_tuple(s):
    """Return the first element if ``s`` is the text of a tuple, else ``s``.

    Only strings that start with ``(`` and end with ``)`` are evaluated.
    If evaluation fails, or the result is not a non-empty tuple, ``s`` is
    returned unchanged.  Non-string inputs are returned unchanged as well.
    """
    if not isinstance(s, str):
        return s
    try:
        if s[0] == "(" and s[-1] == ")":
            # NOTE(security): eval executes arbitrary code. Only feed this
            # trusted strings; ast.literal_eval would be the safe choice if
            # only literal tuples need to be supported.
            t = eval(s)
            if isinstance(t, tuple):
                return t[0]
    except Exception:
        # Best effort: malformed text or an empty tuple falls through to the
        # plain return.  SystemExit / KeyboardInterrupt are no longer swallowed.
        pass
    return s
def autocast(f, enabled=True):
    """Return a wrapper that calls ``f`` inside ``torch.cuda.amp.autocast``.

    The autocast dtype and cache setting are read from torch's current
    global autocast state each time the wrapper is called.
    """

    def do_autocast(*args, **kwargs):
        amp_context = torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        )
        with amp_context:
            result = f(*args, **kwargs)
        return result

    return do_autocast
def expand_dims_like(x, y):
    """Append trailing singleton axes to ``x`` until it has as many dims as ``y``.

    Returns ``x`` itself when the dimension counts already match.

    Raises:
        ValueError: if ``x`` has more dimensions than ``y``.  The previous
            ``while x.dim() != y.dim()`` loop could never reach equality in
            that case, because every iteration only adds a dimension.
    """
    missing = y.dim() - x.dim()
    if missing < 0:
        raise ValueError(
            f"input has {x.dim()} dims but the reference has only {y.dim()}"
        )
    for _ in range(missing):
        x = x.unsqueeze(-1)
    return x
""" return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False, invalidate_cache=True): module, cls = string.rsplit(".", 1) if invalidate_cache: importlib.invalidate_caches() if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError( f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" ) return x[(...,) + (None,) * dims_to_append] def load_model_from_config(config, ckpt, verbose=True, freeze=True): print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) else: raise NotImplementedError model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) if freeze: for param in model.parameters(): param.requires_grad = False model.eval() return model def get_configs_path() -> str: """ Get the `configs` directory. 
def get_configs_path() -> str:
    """
    Get the `configs` directory.
    For a working copy, this is the one in the root of the repository,
    but for an installed copy, it's in the `sgm` package (see pyproject.toml).

    Raises:
        FileNotFoundError: If neither candidate directory exists.
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    for candidate in candidates:
        candidate = os.path.abspath(candidate)
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find SGM configs in {candidates}")


def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
    """
    Will return the result of a recursive get attribute call.
    E.g.:
        a.b.c
        = getattr(getattr(a, "b"), "c")
        = get_nested_attribute(a, "b.c")
    If any part of the attribute call is an integer x with current obj a, will
    try to call a[x] instead of a.x first.

    Args:
        obj: Root object to start from.
        attribute_path: Dot-separated path of attribute names and/or integers.
        depth: If a positive integer, only the first `depth` path components
            are followed.
        return_key: If True, return a tuple of (value, dotted key that was
            actually followed) instead of just the value.

    Raises:
        AttributeError: If a non-integer component is not an attribute.
        TypeError, KeyError, IndexError, ValueError: If an integer component
            can be resolved neither by indexing nor as an attribute; the
            indexing error is the one re-raised.
    """
    attributes = attribute_path.split(".")
    if depth is not None and depth > 0:
        attributes = attributes[:depth]
    assert len(attributes) > 0, "At least one attribute should be selected"

    current_attribute = obj
    current_key = None
    for level, attribute in enumerate(attributes):
        current_key = ".".join(attributes[: level + 1])

        # Keep the try body minimal: only the int() parse decides whether
        # this component is an index or a plain attribute name.
        try:
            index = int(attribute)
        except ValueError:
            current_attribute = getattr(current_attribute, attribute)
            continue

        try:
            current_attribute = current_attribute[index]
        except (TypeError, KeyError, IndexError, ValueError) as index_error:
            # Indexing was only the *first* attempt; fall back to a.x for
            # objects that are not subscriptable or keyed by strings.
            try:
                current_attribute = getattr(current_attribute, attribute)
            except AttributeError:
                # Neither lookup worked: surface the indexing failure, which
                # is what callers saw before the fallback existed.
                raise index_error from None

    return (current_attribute, current_key) if return_key else current_attribute
@pytest.mark.inference
class TestInference:
    """End-to-end sampling checks, parametrized over pipelines and samplers."""

    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> SamplingPipeline:
        """Yield one `SamplingPipeline` per key of `model_specs`, then release it."""
        pipe = SamplingPipeline(request.param)
        yield pipe
        # Drop the reference before asking torch to release cached CUDA memory.
        del pipe
        torch.cuda.empty_cache()

    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
        """Yield a (base, refiner) pipeline pair for each SDXL version, then release both."""
        base_arch, refiner_arch = request.param
        base = SamplingPipeline(base_arch)
        refiner = SamplingPipeline(refiner_arch)
        yield base, refiner
        del base
        del refiner
        torch.cuda.empty_cache()

    def create_init_image(self, h, w):
        """Return the input tensor for a random `h` x `w` RGB image."""
        noise = numpy.random.rand(h, w, 3) * 255
        pil_image = Image.fromarray(noise.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(pil_image)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Text-to-image must return something for every sampler."""
        sampling = SamplingParams(sampler=sampler_enum.value, steps=10)
        result = pipeline.text_to_image(
            params=sampling,
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )

        assert result is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        """Image-to-image from a random init image must return something."""
        sampling = SamplingParams(sampler=sampler_enum.value, steps=10)
        init_image = self.create_init_image(
            pipeline.specs.height, pipeline.specs.width
        )
        result = pipeline.image_to_image(
            params=sampling,
            image=init_image,
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )

        assert result is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        """Run the base pipeline with latents returned, then feed them to the refiner."""
        base, refiner = sdxl_pipelines
        prompt = "A professional photograph of an astronaut riding a pig"
        base_params = SamplingParams(sampler=sampler_enum.value, steps=10)

        if use_init_image:
            init_image = self.create_init_image(base.specs.height, base.specs.width)
            output = base.image_to_image(
                params=base_params,
                image=init_image,
                prompt=prompt,
                negative_prompt="",
                samples=1,
                return_latents=True,
            )
        else:
            output = base.text_to_image(
                params=base_params,
                prompt=prompt,
                negative_prompt="",
                samples=1,
                return_latents=True,
            )

        # With return_latents=True the pipeline is expected to hand back a pair.
        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None

        # A fresh params object is built for the refiner, as in the base call.
        refiner.refiner(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=samples_z,
            prompt=prompt,
            negative_prompt="",
            samples=1,
        )