[
  {
    "path": ".github/workflows/black.yml",
    "content": "name: Run black\non: [pull_request]\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - name: Install venv\n        run: |\n          sudo apt-get -y install python3.10-venv\n      - uses: psf/black@stable\n        with:\n          options: \"--check --verbose -l88\"\n          src: \"./sgm ./scripts ./main.py\"\n"
  },
  {
    "path": ".github/workflows/test-build.yaml",
    "content": "name: Build package\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n\njobs:\n  build:\n    name: Build\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: [\"3.8\", \"3.10\"]\n        requirements-file: [\"pt2\", \"pt13\"]\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v2\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install -r requirements/${{ matrix.requirements-file }}.txt\n          pip install ."
  },
  {
    "path": ".github/workflows/test-inference.yml",
    "content": "name: Test inference\r\n\r\non:\r\n  pull_request:\r\n  push:\r\n    branches:\r\n      - main\r\n\r\njobs:\r\n  test:\r\n    name: \"Test inference\"\r\n    # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment\r\n    if: github.repository == 'stability-ai/generative-models'\r\n    runs-on: [self-hosted, slurm, g40]\r\n    steps:\r\n      - uses: actions/checkout@v3\r\n      - name: \"Symlink checkpoints\"\r\n        run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints\r\n      - name: \"Setup python\"\r\n        uses: actions/setup-python@v4\r\n        with:\r\n          python-version: \"3.10\"\r\n      - name: \"Install Hatch\"\r\n        run: pip install hatch\r\n      - name: \"Run inference tests\"\r\n        run: hatch run ci:test-inference --junit-xml test-results.xml\r\n      - name: Surface failing tests\r\n        if: always()\r\n        uses: pmeier/pytest-results-action@main\r\n        with:\r\n          path: test-results.xml\r\n          summary: true\r\n          display-options: fEX\r\n          fail-on-empty: true\r\n"
  },
  {
    "path": ".gitignore",
    "content": "# extensions\n*.egg-info\n*.py[cod]\n\n# envs\n.pt13\n.pt2\n\n# directories\n/checkpoints\n/dist\n/outputs\n/build\n/src\n/.vscode\n**/__pycache__/\n"
  },
  {
    "path": "CODEOWNERS",
    "content": ".github @Stability-AI/infrastructure"
  },
  {
    "path": "LICENSE-CODE",
    "content": "MIT License\n\nCopyright (c) 2023 Stability AI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Generative Models by Stability AI\n\n![sample1](assets/000.jpg)\n\n## News\n\n\n**May 20, 2025**\n- We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:\n    - **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.\n    - Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on reference multi-views of the first frame generated by SV3D, making it more robust to self-occlusions.\n    - To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames.\n    - Please check our [project page](https://sv4d20.github.io), [arXiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details.\n\n**QUICKSTART** :\n- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from Hugging Face into `checkpoints/`)\n\nTo run **SV4D 2.0** on a single input video of 21 frames:\n- Download the SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints`\n- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`\n    - 
`input_path` : The input video `<path/to/video>` can be\n      - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or\n      - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or\n      - a file name pattern matching images of video frames.\n    - `num_steps` : default is 50; decrease it to shorten sampling time.\n    - `elevations_deg` : specified elevations (relative to the input view), default is 0.0 (same as the input view).\n    - **Background removal** : For input videos with a plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos with a noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D 2.0.\n    - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or a lower video resolution such as `--img_size=512`.\n\nNotes:\n- We also train an 8-view model that generates 5 frames x 8 views at a time (same as SV4D).\n  - Download the model from Hugging Face: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints`\n  - Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs`\n  - The 5x8 model takes 5 frames of input at a time. 
However, the inference scripts for both models take a 21-frame video as input by default (same as SV3D and SV4D), so we run the model autoregressively until all 21 frames are generated.\n- Install dependencies before running:\n```\npython3.10 -m venv .generativemodels\nsource .generativemodels/bin/activate\npip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version\npip3 install -r requirements/pt2.txt\npip3 install .\npip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata\n```\n\n  ![tile](assets/sv4d2.gif)\n\n\n**July 24, 2024**\n- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:\n    - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.\n    - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.\n    - To run the community-built Gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.\n    - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.\n\n**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from Hugging Face into 
`checkpoints/`)\n\nTo run **SV4D** on a single input video of 21 frames:\n- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`\n- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`\n    - `input_path` : The input video `<path/to/video>` can be\n      - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or\n      - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or\n      - a file name pattern matching images of video frames.\n    - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.\n    - `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.\n    - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`\n    - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. 
To obtain higher-quality outputs on real-world input videos with a noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.\n    - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or a lower video resolution such as `--img_size=512`.\n\n  ![tile](assets/sv4d.gif)\n\n\n**March 18, 2024**\n- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:\n    - **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.\n    - **SV3D_u**: This variant generates orbital videos based on single-image inputs without camera conditioning.\n    - **SV3D_p**: Extending the capability of **SV3D_u**, this variant accommodates both single images and orbital views, allowing for the creation of 3D video along specified camera paths.\n    - We extend the streamlit demo `scripts/demo/video_sampling.py` and the standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.\n    - Please check our [project page](https://sv3d.github.io), [tech report](https://sv3d.github.io/static/paper.pdf) and [video summary](https://youtu.be/Zqw4-1LcfWg) for more details.\n\nTo run **SV3D_u** on a single image:\n- Download `sv3d_u.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_u.safetensors`\n- Run `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_u`\n\nTo run **SV3D_p** on a single image:\n- Download `sv3d_p.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_p.safetensors`\n1. Generate a static orbit at a specified elevation, e.g. 
10.0 : `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg 10.0`\n2. Generate a dynamic orbit at specified elevations and azimuths: specify sequences of 21 elevations (in degrees) to `elevations_deg` ([-90, 90]), and 21 azimuths (in degrees) to `azimuths_deg` ([0, 360]) in sorted order from 0 to 360. For example: `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg [<list of 21 elevations in degrees>] --azimuths_deg [<list of 21 azimuths in degrees>]`\n\nTo run SVD or SV3D on a streamlit server:\n`streamlit run scripts/demo/video_sampling.py`\n\n  ![tile](assets/sv3d.gif)\n\n\n**November 28, 2023**\n- We are releasing SDXL-Turbo, a lightning-fast text-to-image model.\n  Alongside the model, we release a [technical report](https://stability.ai/research/adversarial-diffusion-distillation).\n    - Usage:\n        - Follow the installation instructions or update the existing environment with `pip install streamlit-keyup`.\n        - Download the [weights](https://huggingface.co/stabilityai/sdxl-turbo) and place them in the `checkpoints/` directory.\n        - Run `streamlit run scripts/demo/turbo.py`.\n\n  ![tile](assets/turbo_tile.png)\n\n\n**November 21, 2023**\n- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:\n    - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14\n      frames at resolution 576x1024 given a context frame of the same size.\n      We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.\n    - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned\n      for 25-frame generation.\n    - You can run the community-built Gradio demo locally by running `python -m scripts.demo.gradio_app`.\n    - We provide 
a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.\n    - Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).\n\n  ![tile](assets/tile.gif)\n\n**July 26, 2023**\n\n- We are releasing two new open models with a\n  permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file\n  hashes):\n    - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version\n      over `SDXL-base-0.9`.\n    - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version\n      over `SDXL-refiner-0.9`.\n\n![sample2](assets/001_with_eval.png)\n\n**July 4, 2023**\n\n- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).\n\n**June 22, 2023**\n\n- We are releasing two new diffusion models for research purposes:\n    - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. 
The\n      base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)\n      and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding, whereas the refiner model only uses\n      the OpenCLIP model.\n    - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high-quality data and as such is\n      not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.\n\nIf you would like to access these models for your research, please apply using one of the following links:\n[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),\nand [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).\nYou can apply through either link; if access is granted, you can access both models.\nPlease log in to your Hugging Face account with your organization email to request access.\n**We plan to do a full release soon (July).**\n\n## The codebase\n\n### General Philosophy\n\nModularity is king. This repo implements a config-driven approach where we build and combine submodules by\ncalling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.\n\n### Changelog from the old `ldm` codebase\n\nFor training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other\ntraining wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,\nnow `DiffusionEngine`) has been cleaned up:\n\n- No more extensive subclassing! 
We now handle all types of conditioning inputs (vectors, sequences and spatial\n  conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,\n  see `sgm/modules/encoders/modules.py`.\n- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the\n  samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.\n- We adopt the [\"denoiser framework\"](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable\n  change is probably the option to train continuous-time models):\n    * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers);\n      see `sgm/modules/diffusionmodules/denoiser.py`.\n    * The following features are now independent: weighting of the diffusion loss\n      function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the\n      network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during\n      training (`sgm/modules/diffusionmodules/sigma_sampling.py`).\n- Autoencoding models have also been cleaned up.\n\n## Installation:\n\n<a name=\"installation\"></a>\n\n#### 1. Clone the repo\n\n```shell\ngit clone https://github.com/Stability-AI/generative-models.git\ncd generative-models\n```\n\n#### 2. Setting up the virtualenv\n\nThis assumes you have navigated to the `generative-models` root after cloning it.\n\n**NOTE:** This is tested under `python3.10`. For other Python versions, you might encounter version conflicts.\n\n**PyTorch 2.0**\n\n```shell\n# install required packages from pypi\npython3 -m venv .pt2\nsource .pt2/bin/activate\npip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\npip3 install -r requirements/pt2.txt\n```\n\n#### 3. Install `sgm`\n\n```shell\npip3 install .\n```\n\n#### 4. 
Install `sdata` for training\n\n```shell\npip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata\n```\n\n## Packaging\n\nThis repository uses PEP 517-compliant packaging with [Hatch](https://hatch.pypa.io/latest/).\n\nTo build a distributable wheel, install `hatch` and run `hatch build`\n(specifying `-t wheel` will skip building an sdist, which is not necessary).\n\n```\npip install hatch\nhatch build -t wheel\n```\n\nYou will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.\n\nNote that the package does **not** currently specify dependencies; you will need to install the required packages\nmanually, depending on your use case and PyTorch version.\n\n## Inference\n\nWe provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling\nin `scripts/demo/sampling.py`.\nWe provide file hashes for the complete file as well as for only the saved tensors in the file\n(see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).\nThe following models are currently supported:\n\n- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n  ```\n  File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b\n  Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7\n  ```\n- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)\n  ```\n  File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f\n  Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81\n  ```\n- [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)\n- [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)\n\n**Weights for SDXL**:\n\n**SDXL-1.0:**\nThe weights of SDXL-1.0 are available (subject to\na [`CreativeML Open RAIL++-M` 
license](model_licenses/LICENSE-SDXL1.0)) here:\n\n- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/\n- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/\n\n**SDXL-0.9:**\nThe weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).\nIf you would like to access these models for your research, please apply using one of the following links:\n[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),\nand [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).\nYou can apply through either link; if access is granted, you can access both models.\nPlease log in to your Hugging Face account with your organization email to request access.\n\nAfter obtaining the weights, place them into `checkpoints/`.\nNext, start the demo using\n\n```\nstreamlit run scripts/demo/sampling.py --server.port <your_port>\n```\n\n### Invisible Watermark Detection\n\nImages generated with our code use the\n[invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)\nlibrary to embed an invisible watermark into the model output. We also provide\na script to easily detect that watermark. Please note that this watermark is\nnot the same as in previous Stable Diffusion 1.x/2.x versions.\n\nTo run the script, you need either a working installation as above or\nan _experimental_ setup with only a minimal set of packages:\n\n```bash\npython -m venv .detect\nsource .detect/bin/activate\n\npip install \"numpy>=1.17\" \"PyWavelets>=1.1.1\" \"opencv-python>=4.1.0.25\"\npip install --no-deps invisible-watermark\n```\n\nOnce either environment is set up, the script\ncan be used in the following ways (don't forget to activate your\nvirtual environment beforehand, e.g. 
`source .pt2/bin/activate`):\n\n```bash\n# test a single file\npython scripts/demo/detect.py <your filename here>\n# test multiple files at once\npython scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>\n# test all files in a specific folder\npython scripts/demo/detect.py <your folder name here>/*\n```\n\n## Training:\n\nWe provide example training configs in `configs/example_training`. To launch training, run\n\n```\npython main.py --base configs/<config1.yaml> configs/<config2.yaml>\n```\n\nwhere configs are merged from left to right (later configs overwrite the same values).\nThis can be used to combine model, training and data configs. However, all of them can also be\ndefined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,\nrun\n\n```bash\npython main.py --base configs/example_training/toy/mnist_cond.yaml\n```\n\n**NOTE 1:** Using the non-toy-dataset\nconfigs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`\nand `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the\ndataset used (which is expected to be stored in tar files in\nthe [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search\nfor comments containing `USER:` in the respective config.\n\n**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for\nautoencoder training, as e.g. 
in `configs/example_training/imagenet-f8_cond.yaml`) requires\nretrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing\nthe `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#L81). The same applies\nto the provided text-to-image configs.\n\n### Building New Diffusion Models\n\n#### Conditioner\n\nThe `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of\ndifferent embedders (all inheriting from `AbstractEmbModel`) that are used to condition the generative model.\nAll embedders should define whether or not they are trainable (`is_trainable`, default `False`), which classifier-free\nguidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for\ntext-conditioning or `cls` for class-conditioning.\nWhen computing conditionings, the embedder will get `batch[input_key]` as input.\nWe currently support two- to four-dimensional conditionings, and conditionings of different embedders are concatenated\nappropriately.\nNote that the order of the embedders in the `conditioner_config` is important.\n\n#### Network\n\nThe neural network is set through the `network_config`. This used to be called `unet_config`, which is not general\nenough as we plan to experiment with transformer-based diffusion backbones.\n\n#### Loss\n\nThe loss is configured through `loss_config`. For standard diffusion model training, you will have to\nset `sigma_sampler_config`.\n\n#### Sampler config\n\nAs discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical\nsolver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free\nguidance.\n\n### Dataset Handling\n\nFor large-scale training, we recommend using the data pipelines from\nour [data pipelines](https://github.com/Stability-AI/datapipelines) project. 
The project is installed\nas part of the steps in the [Installation section](#installation).\nSmall map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...) and should return a dict of\ndata keys/values, e.g.,\n\n```python\nexample = {\"jpg\": x,  # this is a tensor -1...1 chw\n           \"txt\": \"a beautiful image\"}\n```\n\nwhere we expect images in the range [-1, 1] and in channel-first (CHW) format.\n"
  },
  {
    "path": "configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml",
    "content": "model:\n  base_learning_rate: 4.5e-6\n  target: sgm.models.autoencoder.AutoencodingEngine\n  params:\n    input_key: jpg\n    monitor: val/rec_loss\n\n    loss_config:\n      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator\n      params:\n        perceptual_weight: 0.25\n        disc_start: 20001\n        disc_weight: 0.5\n        learn_logvar: True\n\n        regularization_weights:\n          kl_loss: 1.0\n\n    regularizer_config:\n      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n\n    encoder_config:\n      target: sgm.modules.diffusionmodules.model.Encoder\n      params:\n        attn_type: none\n        double_z: True\n        z_channels: 4\n        resolution: 256\n        in_channels: 3\n        out_ch: 3\n        ch: 128\n        ch_mult: [1, 2, 4]\n        num_res_blocks: 4\n        attn_resolutions: []\n        dropout: 0.0\n\n    decoder_config:\n      target: sgm.modules.diffusionmodules.model.Decoder\n      params: ${model.params.encoder_config.params}\n\ndata:\n  target: sgm.data.dataset.StableDataModuleFromConfig\n  params:\n    train:\n      datapipeline:\n        urls:\n          - DATA-PATH\n        pipeline_config:\n          shardshuffle: 10000\n          sample_shuffle: 10000\n\n        decoders:\n          - pil\n\n        postprocessors:\n          - target: sdata.mappers.TorchVisionImageTransforms\n            params:\n              key: jpg\n              transforms:\n                - target: torchvision.transforms.Resize\n                  params:\n                    size: 256\n                    interpolation: 3\n                - target: torchvision.transforms.ToTensor\n          - target: sdata.mappers.Rescaler\n          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare\n            params:\n              h_key: height\n              w_key: width\n\n      loader:\n        batch_size: 8\n        num_workers: 4\n\n\nlightning:\n  strategy:\n    
target: pytorch_lightning.strategies.DDPStrategy\n    params:\n      find_unused_parameters: True\n\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 50000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        enable_autocast: False\n        batch_frequency: 1000\n        max_images: 8\n        increase_log_steps: True\n\n  trainer:\n    devices: 0,\n    limit_val_batches: 50\n    benchmark: True\n    accumulate_grad_batches: 1\n    val_check_interval: 10000"
  },
  {
    "path": "configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml",
    "content": "model:\n  base_learning_rate: 4.5e-6\n  target: sgm.models.autoencoder.AutoencodingEngine\n  params:\n    input_key: jpg\n    monitor: val/loss/rec\n    disc_start_iter: 0\n\n    encoder_config:\n      target: sgm.modules.diffusionmodules.model.Encoder\n      params:\n        attn_type: vanilla-xformers\n        double_z: true\n        z_channels: 8\n        resolution: 256\n        in_channels: 3\n        out_ch: 3\n        ch: 128\n        ch_mult: [1, 2, 4, 4]\n        num_res_blocks: 2\n        attn_resolutions: []\n        dropout: 0.0\n\n    decoder_config:\n      target: sgm.modules.diffusionmodules.model.Decoder\n      params: ${model.params.encoder_config.params}\n\n    regularizer_config:\n      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n\n    loss_config:\n      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator\n      params:\n        perceptual_weight: 0.25\n        disc_start: 20001\n        disc_weight: 0.5\n        learn_logvar: True\n\n        regularization_weights:\n          kl_loss: 1.0\n\ndata:\n  target: sgm.data.dataset.StableDataModuleFromConfig\n  params:\n    train:\n      datapipeline:\n        urls:\n          - DATA-PATH\n        pipeline_config:\n          shardshuffle: 10000\n          sample_shuffle: 10000\n\n        decoders:\n          - pil\n\n        postprocessors:\n          - target: sdata.mappers.TorchVisionImageTransforms\n            params:\n              key: jpg\n              transforms:\n                - target: torchvision.transforms.Resize\n                  params:\n                    size: 256\n                    interpolation: 3\n                - target: torchvision.transforms.ToTensor\n          - target: sdata.mappers.Rescaler\n          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare\n            params:\n              h_key: height\n              w_key: width\n\n      loader:\n        batch_size: 8\n        num_workers: 
4\n\n\nlightning:\n  strategy:\n    target: pytorch_lightning.strategies.DDPStrategy\n    params:\n      find_unused_parameters: True\n\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 50000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        enable_autocast: False\n        batch_frequency: 1000\n        max_images: 8\n        increase_log_steps: True\n\n  trainer:\n    devices: 0,\n    limit_val_batches: 50\n    benchmark: True\n    accumulate_grad_batches: 1\n    val_check_interval: 10000\n"
  },
  {
    "path": "configs/example_training/imagenet-f8_cond.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.13025\n    disable_first_stage_autocast: True\n    log_keys:\n      - cls\n\n    scheduler_config:\n      target: sgm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [10000]\n        cycle_lengths: [10000000000000]\n        f_start: [1.e-6]\n        f_max: [1.]\n        f_min: [1.]\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        use_checkpoint: True\n        in_channels: 4\n        out_channels: 4\n        model_channels: 256\n        attention_resolutions: [1, 2, 4]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4]\n        num_head_channels: 64\n        num_classes: sequential\n        adm_in_channels: 1024\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: cls\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ClassEmbedder\n            params:\n              add_sequence_dim: True\n              embed_dim: 1024\n              n_classes: 1000\n\n          - is_trainable: False\n            ucg_rate: 0.2\n            input_key: original_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: False\n            input_key: 
crop_coords_top_left\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        ckpt_path: CKPT_PATH\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling\n          params:\n            num_idx: 1000\n\n            discretization_config:\n              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 5.0\n\ndata:\n  target: sgm.data.dataset.StableDataModuleFromConfig\n  params:\n    train:\n      datapipeline:\n        urls:\n          # USER: adapt this path to the root of your custom dataset\n          - DATA_PATH\n        pipeline_config:\n          shardshuffle: 10000\n          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM\n\n        decoders:\n          - pil\n\n        postprocessors:\n          - target: sdata.mappers.TorchVisionImageTransforms\n            params:\n              key: jpg # USER: you might wanna adapt this for your custom dataset\n              transforms:\n                - target: torchvision.transforms.Resize\n                  params:\n                    size: 256\n                    interpolation: 3\n                - target: torchvision.transforms.ToTensor\n          - target: sdata.mappers.Rescaler\n\n          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare\n            params:\n              h_key: height # USER: you might wanna adapt this for your custom dataset\n              w_key: width # USER: you might wanna adapt this for your custom dataset\n\n      loader:\n        batch_size: 64\n        num_workers: 6\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        enable_autocast: False\n        batch_frequency: 1000\n        max_images: 8\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 8\n          n_rows: 2\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 1000"
  },
  {
    "path": "configs/example_training/toy/cifar10_cond.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling\n          params:\n            sigma_data: 1.0\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        in_channels: 3\n        out_channels: 3\n        model_channels: 32\n        attention_resolutions: []\n        num_res_blocks: 4\n        channel_mult: [1, 2, 2]\n        num_head_channels: 32\n        num_classes: sequential\n        adm_in_channels: 128\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: cls\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ClassEmbedder\n            params:\n              embed_dim: 128\n              n_classes: 10\n\n    first_stage_config:\n      target: sgm.models.autoencoder.IdentityFirstStage\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting\n          params:\n            sigma_data: 1.0\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 3.0\n\ndata:\n  target: sgm.data.cifar10.CIFAR10Loader\n  params:\n    batch_size: 512\n    
num_workers: 1\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        batch_frequency: 1000\n        max_images: 64\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 64\n          n_rows: 8\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 20"
  },
  {
    "path": "configs/example_training/toy/mnist.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling\n          params:\n            sigma_data: 1.0\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        in_channels: 1\n        out_channels: 1\n        model_channels: 32\n        attention_resolutions: []\n        num_res_blocks: 4\n        channel_mult: [1, 2, 2]\n        num_head_channels: 32\n\n    first_stage_config:\n      target: sgm.models.autoencoder.IdentityFirstStage\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting\n          params:\n            sigma_data: 1.0\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n\ndata:\n  target: sgm.data.mnist.MNISTLoader\n  params:\n    batch_size: 512\n    num_workers: 1\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        batch_frequency: 1000\n        max_images: 64\n        increase_log_steps: False\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 64\n          n_rows: 8\n\n  trainer:\n    devices: 0,\n    
benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 10"
  },
  {
    "path": "configs/example_training/toy/mnist_cond.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling\n          params:\n            sigma_data: 1.0\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        in_channels: 1\n        out_channels: 1\n        model_channels: 32\n        attention_resolutions: []\n        num_res_blocks: 4\n        channel_mult: [1, 2, 2]\n        num_head_channels: 32\n        num_classes: sequential\n        adm_in_channels: 128\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: cls\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ClassEmbedder\n            params:\n              embed_dim: 128\n              n_classes: 10\n\n    first_stage_config:\n      target: sgm.models.autoencoder.IdentityFirstStage\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting\n          params:\n            sigma_data: 1.0\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 3.0\n\ndata:\n  target: sgm.data.mnist.MNISTLoader\n  params:\n    batch_size: 512\n    num_workers: 
1\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        batch_frequency: 1000\n        max_images: 16\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 16\n          n_rows: 4\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 20"
  },
  {
    "path": "configs/example_training/toy/mnist_cond_discrete_eps.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        in_channels: 1\n        out_channels: 1\n        model_channels: 32\n        attention_resolutions: []\n        num_res_blocks: 4\n        channel_mult: [1, 2, 2]\n        num_head_channels: 32\n        num_classes: sequential\n        adm_in_channels: 128\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: cls\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ClassEmbedder\n            params:\n              embed_dim: 128\n              n_classes: 10\n\n    first_stage_config:\n      target: sgm.models.autoencoder.IdentityFirstStage\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling\n          params:\n            num_idx: 1000\n\n            discretization_config:\n              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: 
sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 5.0\n\ndata:\n  target: sgm.data.mnist.MNISTLoader\n  params:\n    batch_size: 512\n    num_workers: 1\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        batch_frequency: 1000\n        max_images: 16\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 16\n          n_rows: 4\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 20"
  },
  {
    "path": "configs/example_training/toy/mnist_cond_l1_loss.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling\n          params:\n            sigma_data: 1.0\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        in_channels: 1\n        out_channels: 1\n        model_channels: 32\n        attention_resolutions: []\n        num_res_blocks: 4\n        channel_mult: [1, 2, 2]\n        num_head_channels: 32\n        num_classes: sequential\n        adm_in_channels: 128\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: cls\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ClassEmbedder\n            params:\n              embed_dim: 128\n              n_classes: 10\n\n    first_stage_config:\n      target: sgm.models.autoencoder.IdentityFirstStage\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_type: l1\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting\n          params:\n            sigma_data: 1.0\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 3.0\n\ndata:\n  target: sgm.data.mnist.MNISTLoader\n  params:\n    
batch_size: 512\n    num_workers: 1\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        batch_frequency: 1000\n        max_images: 64\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 64\n          n_rows: 8\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 20"
  },
  {
    "path": "configs/example_training/toy/mnist_cond_with_ema.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    use_ema: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling\n          params:\n            sigma_data: 1.0\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        in_channels: 1\n        out_channels: 1\n        model_channels: 32\n        attention_resolutions: []\n        num_res_blocks: 4\n        channel_mult: [1, 2, 2]\n        num_head_channels: 32\n        num_classes: sequential\n        adm_in_channels: 128\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: cls\n            ucg_rate: 0.2\n            target: sgm.modules.encoders.modules.ClassEmbedder\n            params:\n              embed_dim: 128\n              n_classes: 10\n\n    first_stage_config:\n      target: sgm.models.autoencoder.IdentityFirstStage\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting\n          params:\n            sigma_data: 1.0\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 3.0\n\ndata:\n  target: sgm.data.mnist.MNISTLoader\n  params:\n    batch_size: 
512\n    num_workers: 1\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        batch_frequency: 1000\n        max_images: 64\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 64\n          n_rows: 8\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 20"
  },
  {
    "path": "configs/example_training/txt2img-clipl-legacy-ucg-training.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.13025\n    disable_first_stage_autocast: True\n    log_keys:\n      - txt\n\n    scheduler_config:\n      target: sgm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [10000]\n        cycle_lengths: [10000000000000]\n        f_start: [1.e-6]\n        f_max: [1.]\n        f_min: [1.]\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        use_checkpoint: True\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [1, 2, 4]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        num_classes: sequential\n        adm_in_channels: 1792\n        num_heads: 1\n        transformer_depth: 1\n        context_dim: 768\n        spatial_transformer_attn_type: softmax-xformers\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: txt\n            ucg_rate: 0.1\n            legacy_ucg_value: \"\"\n            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder\n            params:\n              always_return_pooled: True\n\n          - is_trainable: False\n            ucg_rate: 0.1\n            input_key: original_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: False\n            input_key: 
crop_coords_top_left\n            ucg_rate: 0.1\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        ckpt_path: CKPT_PATH\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [ 1, 2, 4, 4 ]\n          num_res_blocks: 2\n          attn_resolutions: [ ]\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling\n          params:\n            num_idx: 1000\n\n            discretization_config:\n              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 7.5\n\ndata:\n  target: sgm.data.dataset.StableDataModuleFromConfig\n  params:\n    train:\n      datapipeline:\n        urls:\n          # USER: adapt this path to the root of your custom dataset\n          - DATA_PATH\n        pipeline_config:\n          shardshuffle: 10000\n          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM\n\n        decoders:\n          - pil\n\n        postprocessors:\n          - target: sdata.mappers.TorchVisionImageTransforms\n            params:\n              key: jpg # USER: you might wanna adapt this for your custom dataset\n              transforms:\n                - target: torchvision.transforms.Resize\n                  params:\n                    size: 256\n                    interpolation: 3\n                - target: torchvision.transforms.ToTensor\n          - target: sdata.mappers.Rescaler\n          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare\n            # USER: you might wanna use non-default parameters due to your custom dataset\n\n      loader:\n        batch_size: 64\n        num_workers: 6\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        enable_autocast: False\n        batch_frequency: 1000\n        max_images: 8\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 8\n          n_rows: 2\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 1000"
  },
  {
    "path": "configs/example_training/txt2img-clipl.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-4\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.13025\n    disable_first_stage_autocast: True\n    log_keys:\n      - txt\n\n    scheduler_config:\n      target: sgm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [10000]\n        cycle_lengths: [10000000000000]\n        f_start: [1.e-6]\n        f_max: [1.]\n        f_min: [1.]\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        use_checkpoint: True\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [1, 2, 4]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        num_classes: sequential\n        adm_in_channels: 1792\n        num_heads: 1\n        transformer_depth: 1\n        context_dim: 768\n        spatial_transformer_attn_type: softmax-xformers\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: True\n            input_key: txt\n            ucg_rate: 0.1\n            legacy_ucg_value: \"\"\n            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder\n            params:\n              always_return_pooled: True\n\n          - is_trainable: False\n            ucg_rate: 0.1\n            input_key: original_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: False\n            input_key: 
crop_coords_top_left\n            ucg_rate: 0.1\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        ckpt_path: CKPT_PATH\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    loss_fn_config:\n      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss\n      params:\n        loss_weighting_config:\n          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting\n        sigma_sampler_config:\n          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling\n          params:\n            num_idx: 1000\n\n            discretization_config:\n              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.VanillaCFG\n          params:\n            scale: 7.5\n\ndata:\n  target: sgm.data.dataset.StableDataModuleFromConfig\n  params:\n    train:\n      datapipeline:\n        urls:\n          # USER: adapt this path to the root of your custom dataset\n          - DATA_PATH\n        pipeline_config:\n          shardshuffle: 10000\n          sample_shuffle: 10000\n\n        decoders:\n          - pil\n\n        postprocessors:\n          - target: sdata.mappers.TorchVisionImageTransforms\n            params:\n              key: jpg # USER: you might wanna adapt this for your custom dataset\n              transforms:\n                - target: torchvision.transforms.Resize\n                  params:\n                    size: 256\n                    interpolation: 3\n                - target: torchvision.transforms.ToTensor\n          - target: sdata.mappers.Rescaler\n            # USER: you might wanna use non-default parameters due to your custom dataset\n          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare\n            # USER: you might wanna use non-default parameters due to your custom dataset\n\n      loader:\n        batch_size: 64\n        num_workers: 6\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 5000\n\n  callbacks:\n    metrics_over_trainsteps_checkpoint:\n      params:\n        every_n_train_steps: 25000\n\n    image_logger:\n      target: main.ImageLogger\n      params:\n        disabled: False\n        enable_autocast: False\n        batch_frequency: 1000\n        max_images: 8\n        increase_log_steps: True\n        log_first_step: False\n        log_images_kwargs:\n          use_ema_scope: False\n          N: 8\n          n_rows: 2\n\n  trainer:\n    devices: 0,\n    benchmark: True\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    max_epochs: 1000"
  },
  {
    "path": "configs/inference/sd_xl_base.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.13025\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        adm_in_channels: 2816\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: [1, 2, 10]\n        context_dim: 2048\n        spatial_transformer_attn_type: softmax-xformers\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: False\n            input_key: txt\n            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder\n            params:\n              layer: hidden\n              layer_idx: 11\n\n          - is_trainable: False\n            input_key: txt\n            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2\n            params:\n              arch: ViT-bigG-14\n              version: laion2b_s39b_b160k\n              freeze: True\n              layer: penultimate\n              always_return_pooled: True\n              legacy: False\n\n          - is_trainable: False\n            input_key: original_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: False\n   
         input_key: crop_coords_top_left\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: False\n            input_key: target_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n"
  },
  {
    "path": "configs/inference/sd_xl_refiner.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.13025\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        adm_in_channels: 2560\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 4\n        out_channels: 4\n        model_channels: 384\n        attention_resolutions: [4, 2]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 4\n        context_dim: [1280, 1280, 1280, 1280]\n        spatial_transformer_attn_type: softmax-xformers\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          - is_trainable: False\n            input_key: txt\n            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2\n            params:\n              arch: ViT-bigG-14\n              version: laion2b_s39b_b160k\n              legacy: False\n              freeze: True\n              layer: penultimate\n              always_return_pooled: True\n\n          - is_trainable: False\n            input_key: original_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: False\n            input_key: crop_coords_top_left\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n          - is_trainable: 
False\n            input_key: aesthetic_score\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n"
  },
  {
    "path": "configs/inference/sv3d_p.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 1280\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - input_key: cond_frames_without_noise\n          is_trainable: False\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n      
        params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: polars_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: azimuths_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n          target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_type: vanilla-xformers\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [ 1, 2, 4, 4 ]\n            num_res_blocks: 2\n            attn_resolutions: [ ]\n            dropout: 0.0"
  },
  {
    "path": "configs/inference/sv3d_u.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 256\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - input_key: cond_frames_without_noise\n          is_trainable: False\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n       
       params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n          target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_type: vanilla-xformers\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [ 1, 2, 4, 4 ]\n            num_res_blocks: 2\n            attn_resolutions: [ ]\n            dropout: 0.0"
  },
  {
    "path": "configs/inference/svd.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: 
cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config: \n          target: sgm.modules.diffusionmodules.model.Encoder\n          params:\n            attn_type: vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n        decoder_config:\n          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder\n          params:\n            attn_type: vanilla\n            double_z: True\n    
        z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n            video_kernel_size: [3, 1, 1]"
  },
  {
    "path": "configs/inference/svd_image_decoder.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: 
cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: True\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity"
  },
  {
    "path": "main.py",
    "content": "import argparse\nimport datetime\nimport glob\nimport inspect\nimport os\nimport sys\nfrom inspect import Parameter\nfrom typing import Union\n\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch\nimport torchvision\nimport wandb\nfrom matplotlib import pyplot as plt\nfrom natsort import natsorted\nfrom omegaconf import OmegaConf\nfrom packaging import version\nfrom PIL import Image\nfrom pytorch_lightning import seed_everything\nfrom pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.loggers import WandbLogger\nfrom pytorch_lightning.trainer import Trainer\nfrom pytorch_lightning.utilities import rank_zero_only\n\nfrom sgm.util import exists, instantiate_from_config, isheatmap\n\nMULTINODE_HACKS = True\n\n\ndef default_trainer_args():\n    argspec = dict(inspect.signature(Trainer.__init__).parameters)\n    argspec.pop(\"self\")\n    default_args = {\n        param: argspec[param].default\n        for param in argspec\n        if argspec[param] != Parameter.empty\n    }\n    return default_args\n\n\ndef get_parser(**parser_kwargs):\n    def str2bool(v):\n        if isinstance(v, bool):\n            return v\n        if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n            return True\n        elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n            return False\n        else:\n            raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\n    parser = argparse.ArgumentParser(**parser_kwargs)\n    parser.add_argument(\n        \"-n\",\n        \"--name\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"postfix for logdir\",\n    )\n    parser.add_argument(\n        \"--no_date\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)\",\n    )\n    
parser.add_argument(\n        \"-r\",\n        \"--resume\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"resume from logdir or checkpoint in logdir\",\n    )\n    parser.add_argument(\n        \"-b\",\n        \"--base\",\n        nargs=\"*\",\n        metavar=\"base_config.yaml\",\n        help=\"paths to base configs. Loaded from left-to-right. \"\n        \"Parameters can be overwritten or added with command-line options of the form `--key value`.\",\n        default=list(),\n    )\n    parser.add_argument(\n        \"-t\",\n        \"--train\",\n        type=str2bool,\n        const=True,\n        default=True,\n        nargs=\"?\",\n        help=\"train\",\n    )\n    parser.add_argument(\n        \"--no-test\",\n        type=str2bool,\n        const=True,\n        default=False,\n        nargs=\"?\",\n        help=\"disable test\",\n    )\n    parser.add_argument(\n        \"-p\", \"--project\", help=\"name of new or path to existing project\"\n    )\n    parser.add_argument(\n        \"-d\",\n        \"--debug\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"enable post-mortem debugging\",\n    )\n    parser.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        default=23,\n        help=\"seed for seed_everything\",\n    )\n    parser.add_argument(\n        \"-f\",\n        \"--postfix\",\n        type=str,\n        default=\"\",\n        help=\"post-postfix for default name\",\n    )\n    parser.add_argument(\n        \"--projectname\",\n        type=str,\n        default=\"stablediffusion\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--logdir\",\n        type=str,\n        default=\"logs\",\n        help=\"directory for logging\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        
help=\"scale base-lr by ngpu * batch_size * n_accumulate\",\n    )\n    parser.add_argument(\n        \"--legacy_naming\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"name run based on config file name if true, else by whole path\",\n    )\n    parser.add_argument(\n        \"--enable_tf32\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12\",\n    )\n    parser.add_argument(\n        \"--startup\",\n        type=str,\n        default=None,\n        help=\"Startup time from distributed script\",\n    )\n    parser.add_argument(\n        \"--wandb\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,  # TODO: later default to True\n        help=\"log to wandb\",\n    )\n    parser.add_argument(\n        \"--no_base_name\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,  # TODO: later default to True\n        help=\"if True, do not use the base config file name in the run name\",\n    )\n    if version.parse(torch.__version__) >= version.parse(\"2.0.0\"):\n        parser.add_argument(\n            \"--resume_from_checkpoint\",\n            type=str,\n            default=None,\n            help=\"single checkpoint file to resume from\",\n        )\n    default_args = default_trainer_args()\n    for key in default_args:\n        parser.add_argument(\"--\" + key, default=default_args[key])\n    return parser\n\n\ndef get_checkpoint_name(logdir):\n    ckpt = os.path.join(logdir, \"checkpoints\", \"last**.ckpt\")\n    ckpt = natsorted(glob.glob(ckpt))\n    print('available \"last\" checkpoints:')\n    print(ckpt)\n    if len(ckpt) > 1:\n        print(\"got most recent checkpoint\")\n        ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]\n        print(f\"Most recent ckpt is {ckpt}\")\n        with open(os.path.join(logdir, 
\"most_recent_ckpt.txt\"), \"w\") as f:\n            f.write(ckpt + \"\\n\")\n        try:\n            version = int(ckpt.split(\"/\")[-1].split(\"-v\")[-1].split(\".\")[0])\n        except Exception as e:\n            print(\"version confusion but not bad\")\n            print(e)\n            version = 1\n        # version = last_version + 1\n    else:\n        # in this case, we only have one \"last.ckpt\"\n        ckpt = ckpt[0]\n        version = 1\n    melk_ckpt_name = f\"last-v{version}.ckpt\"\n    print(f\"Current melk ckpt name: {melk_ckpt_name}\")\n    return ckpt, melk_ckpt_name\n\n\nclass SetupCallback(Callback):\n    def __init__(\n        self,\n        resume,\n        now,\n        logdir,\n        ckptdir,\n        cfgdir,\n        config,\n        lightning_config,\n        debug,\n        ckpt_name=None,\n    ):\n        super().__init__()\n        self.resume = resume\n        self.now = now\n        self.logdir = logdir\n        self.ckptdir = ckptdir\n        self.cfgdir = cfgdir\n        self.config = config\n        self.lightning_config = lightning_config\n        self.debug = debug\n        self.ckpt_name = ckpt_name\n\n    def on_exception(self, trainer: pl.Trainer, pl_module, exception):\n        if not self.debug and trainer.global_rank == 0:\n            print(\"Summoning checkpoint.\")\n            if self.ckpt_name is None:\n                ckpt_path = os.path.join(self.ckptdir, \"last.ckpt\")\n            else:\n                ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)\n            trainer.save_checkpoint(ckpt_path)\n\n    def on_fit_start(self, trainer, pl_module):\n        if trainer.global_rank == 0:\n            # Create logdirs and save configs\n            os.makedirs(self.logdir, exist_ok=True)\n            os.makedirs(self.ckptdir, exist_ok=True)\n            os.makedirs(self.cfgdir, exist_ok=True)\n\n            if \"callbacks\" in self.lightning_config:\n                if (\n                    
\"metrics_over_trainsteps_checkpoint\"\n                    in self.lightning_config[\"callbacks\"]\n                ):\n                    os.makedirs(\n                        os.path.join(self.ckptdir, \"trainstep_checkpoints\"),\n                        exist_ok=True,\n                    )\n            print(\"Project config\")\n            print(OmegaConf.to_yaml(self.config))\n            if MULTINODE_HACKS:\n                import time\n\n                time.sleep(5)\n            OmegaConf.save(\n                self.config,\n                os.path.join(self.cfgdir, \"{}-project.yaml\".format(self.now)),\n            )\n\n            print(\"Lightning config\")\n            print(OmegaConf.to_yaml(self.lightning_config))\n            OmegaConf.save(\n                OmegaConf.create({\"lightning\": self.lightning_config}),\n                os.path.join(self.cfgdir, \"{}-lightning.yaml\".format(self.now)),\n            )\n\n        else:\n            # ModelCheckpoint callback created log directory --- remove it\n            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):\n                dst, name = os.path.split(self.logdir)\n                dst = os.path.join(dst, \"child_runs\", name)\n                os.makedirs(os.path.split(dst)[0], exist_ok=True)\n                try:\n                    os.rename(self.logdir, dst)\n                except FileNotFoundError:\n                    pass\n\n\nclass ImageLogger(Callback):\n    def __init__(\n        self,\n        batch_frequency,\n        max_images,\n        clamp=True,\n        increase_log_steps=True,\n        rescale=True,\n        disabled=False,\n        log_on_batch_idx=False,\n        log_first_step=False,\n        log_images_kwargs=None,\n        log_before_first_step=False,\n        enable_autocast=True,\n    ):\n        super().__init__()\n        self.enable_autocast = enable_autocast\n        self.rescale = rescale\n        self.batch_freq = batch_frequency\n   
     self.max_images = max_images\n        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]\n        if not increase_log_steps:\n            self.log_steps = [self.batch_freq]\n        self.clamp = clamp\n        self.disabled = disabled\n        self.log_on_batch_idx = log_on_batch_idx\n        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}\n        self.log_first_step = log_first_step\n        self.log_before_first_step = log_before_first_step\n\n    @rank_zero_only\n    def log_local(\n        self,\n        save_dir,\n        split,\n        images,\n        global_step,\n        current_epoch,\n        batch_idx,\n        pl_module: Union[None, pl.LightningModule] = None,\n    ):\n        root = os.path.join(save_dir, \"images\", split)\n        for k in images:\n            if isheatmap(images[k]):\n                fig, ax = plt.subplots()\n                ax = ax.matshow(\n                    images[k].cpu().numpy(), cmap=\"hot\", interpolation=\"lanczos\"\n                )\n                plt.colorbar(ax)\n                plt.axis(\"off\")\n\n                filename = \"{}_gs-{:06}_e-{:06}_b-{:06}.png\".format(\n                    k, global_step, current_epoch, batch_idx\n                )\n                os.makedirs(root, exist_ok=True)\n                path = os.path.join(root, filename)\n                plt.savefig(path)\n                plt.close()\n                # TODO: support wandb\n            else:\n                grid = torchvision.utils.make_grid(images[k], nrow=4)\n                if self.rescale:\n                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w\n                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)\n                grid = grid.numpy()\n                grid = (grid * 255).astype(np.uint8)\n                filename = \"{}_gs-{:06}_e-{:06}_b-{:06}.png\".format(\n                    k, global_step, current_epoch, batch_idx\n                )\n            
    path = os.path.join(root, filename)\n                os.makedirs(os.path.split(path)[0], exist_ok=True)\n                img = Image.fromarray(grid)\n                img.save(path)\n                if exists(pl_module):\n                    assert isinstance(\n                        pl_module.logger, WandbLogger\n                    ), \"logger_log_image only supports WandbLogger currently\"\n                    pl_module.logger.log_image(\n                        key=f\"{split}/{k}\",\n                        images=[\n                            img,\n                        ],\n                        step=pl_module.global_step,\n                    )\n\n    @rank_zero_only\n    def log_img(self, pl_module, batch, batch_idx, split=\"train\"):\n        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step\n        if (\n            self.check_frequency(check_idx)\n            and hasattr(pl_module, \"log_images\")  # batch_idx % self.batch_freq == 0\n            and callable(pl_module.log_images)\n            and\n            # batch_idx > 5 and\n            self.max_images > 0\n        ):\n            logger = type(pl_module.logger)\n            is_train = pl_module.training\n            if is_train:\n                pl_module.eval()\n\n            gpu_autocast_kwargs = {\n                \"enabled\": self.enable_autocast,  # torch.is_autocast_enabled(),\n                \"dtype\": torch.get_autocast_gpu_dtype(),\n                \"cache_enabled\": torch.is_autocast_cache_enabled(),\n            }\n            with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):\n                images = pl_module.log_images(\n                    batch, split=split, **self.log_images_kwargs\n                )\n\n            for k in images:\n                N = min(images[k].shape[0], self.max_images)\n                if not isheatmap(images[k]):\n                    images[k] = images[k][:N]\n                if isinstance(images[k], 
torch.Tensor):\n                    images[k] = images[k].detach().float().cpu()\n                    if self.clamp and not isheatmap(images[k]):\n                        images[k] = torch.clamp(images[k], -1.0, 1.0)\n\n            self.log_local(\n                pl_module.logger.save_dir,\n                split,\n                images,\n                pl_module.global_step,\n                pl_module.current_epoch,\n                batch_idx,\n                pl_module=pl_module\n                if isinstance(pl_module.logger, WandbLogger)\n                else None,\n            )\n\n            if is_train:\n                pl_module.train()\n\n    def check_frequency(self, check_idx):\n        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (\n            check_idx > 0 or self.log_first_step\n        ):\n            try:\n                self.log_steps.pop(0)\n            except IndexError as e:\n                print(e)\n                pass\n            return True\n        return False\n\n    @rank_zero_only\n    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):\n        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):\n            self.log_img(pl_module, batch, batch_idx, split=\"train\")\n\n    @rank_zero_only\n    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):\n        if self.log_before_first_step and pl_module.global_step == 0:\n            print(f\"{self.__class__.__name__}: logging before training\")\n            self.log_img(pl_module, batch, batch_idx, split=\"train\")\n\n    @rank_zero_only\n    def on_validation_batch_end(\n        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs\n    ):\n        if not self.disabled and pl_module.global_step > 0:\n            self.log_img(pl_module, batch, batch_idx, split=\"val\")\n        if hasattr(pl_module, \"calibrate_grad_norm\"):\n            if (\n                
pl_module.calibrate_grad_norm and batch_idx % 25 == 0\n            ) and batch_idx > 0:\n                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)\n\n\n@rank_zero_only\ndef init_wandb(save_dir, opt, config, group_name, name_str):\n    print(f\"setting WANDB_DIR to {save_dir}\")\n    os.makedirs(save_dir, exist_ok=True)\n\n    os.environ[\"WANDB_DIR\"] = save_dir\n    if opt.debug:\n        wandb.init(project=opt.projectname, mode=\"offline\", group=group_name)\n    else:\n        wandb.init(\n            project=opt.projectname,\n            config=config,\n            settings=wandb.Settings(code_dir=\"./sgm\"),\n            group=group_name,\n            name=name_str,\n        )\n\n\nif __name__ == \"__main__\":\n    # custom parser to specify config files, train, test and debug mode,\n    # postfix, resume.\n    # `--key value` arguments are interpreted as arguments to the trainer.\n    # `nested.key=value` arguments are interpreted as config parameters.\n    # configs are merged from left-to-right followed by command line parameters.\n\n    # model:\n    #   base_learning_rate: float\n    #   target: path to lightning module\n    #   params:\n    #       key: value\n    # data:\n    #   target: main.DataModuleFromConfig\n    #   params:\n    #      batch_size: int\n    #      wrap: bool\n    #      train:\n    #          target: path to train dataset\n    #          params:\n    #              key: value\n    #      validation:\n    #          target: path to validation dataset\n    #          params:\n    #              key: value\n    #      test:\n    #          target: path to test dataset\n    #          params:\n    #              key: value\n    # lightning: (optional, has sane defaults and can be specified on cmdline)\n    #   trainer:\n    #       additional arguments to trainer\n    #   logger:\n    #       logger to instantiate\n    #   modelcheckpoint:\n    #       modelcheckpoint to instantiate\n    #   callbacks:\n    #       
callback1:\n    #           target: importpath\n    #           params:\n    #               key: value\n\n    now = datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n\n    # add cwd for convenience and to make classes in this file available when\n    # running as `python main.py`\n    # (in particular `main.DataModuleFromConfig`)\n    sys.path.append(os.getcwd())\n\n    parser = get_parser()\n\n    opt, unknown = parser.parse_known_args()\n\n    if opt.name and opt.resume:\n        raise ValueError(\n            \"-n/--name and -r/--resume cannot be specified both.\"\n            \"If you want to resume training in a new log folder, \"\n            \"use -n/--name in combination with --resume_from_checkpoint\"\n        )\n    melk_ckpt_name = None\n    name = None\n    if opt.resume:\n        if not os.path.exists(opt.resume):\n            raise ValueError(\"Cannot find {}\".format(opt.resume))\n        if os.path.isfile(opt.resume):\n            paths = opt.resume.split(\"/\")\n            # idx = len(paths)-paths[::-1].index(\"logs\")+1\n            # logdir = \"/\".join(paths[:idx])\n            logdir = \"/\".join(paths[:-2])\n            ckpt = opt.resume\n            _, melk_ckpt_name = get_checkpoint_name(logdir)\n        else:\n            assert os.path.isdir(opt.resume), opt.resume\n            logdir = opt.resume.rstrip(\"/\")\n            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)\n\n        print(\"#\" * 100)\n        print(f'Resuming from checkpoint \"{ckpt}\"')\n        print(\"#\" * 100)\n\n        opt.resume_from_checkpoint = ckpt\n        base_configs = sorted(glob.glob(os.path.join(logdir, \"configs/*.yaml\")))\n        opt.base = base_configs + opt.base\n        _tmp = logdir.split(\"/\")\n        nowname = _tmp[-1]\n    else:\n        if opt.name:\n            name = \"_\" + opt.name\n        elif opt.base:\n            if opt.no_base_name:\n                name = \"\"\n            else:\n                if opt.legacy_naming:\n  
                  cfg_fname = os.path.split(opt.base[0])[-1]\n                    cfg_name = os.path.splitext(cfg_fname)[0]\n                else:\n                    assert \"configs\" in os.path.split(opt.base[0])[0], os.path.split(\n                        opt.base[0]\n                    )[0]\n                    cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[\n                        os.path.split(opt.base[0])[0].split(os.sep).index(\"configs\")\n                        + 1 :\n                    ]  # cut away the first one (we assert all configs are in \"configs\")\n                    cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]\n                    cfg_name = \"-\".join(cfg_path) + f\"-{cfg_name}\"\n                name = \"_\" + cfg_name\n        else:\n            name = \"\"\n        if not opt.no_date:\n            nowname = now + name + opt.postfix\n        else:\n            nowname = name + opt.postfix\n            if nowname.startswith(\"_\"):\n                nowname = nowname[1:]\n        logdir = os.path.join(opt.logdir, nowname)\n        print(f\"LOGDIR: {logdir}\")\n\n    ckptdir = os.path.join(logdir, \"checkpoints\")\n    cfgdir = os.path.join(logdir, \"configs\")\n    seed_everything(opt.seed, workers=True)\n\n    # move before model init, in case a torch.compile(...) 
is called somewhere\n    if opt.enable_tf32:\n        # pt_version = version.parse(torch.__version__)\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.allow_tf32 = True\n        print(f\"Enabling TF32 for PyTorch {torch.__version__}\")\n    else:\n        print(f\"Using default TF32 settings for PyTorch {torch.__version__}:\")\n        print(\n            f\"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}\"\n        )\n        print(f\"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}\")\n\n    try:\n        # init and save configs\n        configs = [OmegaConf.load(cfg) for cfg in opt.base]\n        cli = OmegaConf.from_dotlist(unknown)\n        config = OmegaConf.merge(*configs, cli)\n        lightning_config = config.pop(\"lightning\", OmegaConf.create())\n        # merge trainer cli with config\n        trainer_config = lightning_config.get(\"trainer\", OmegaConf.create())\n\n        # default to gpu\n        trainer_config[\"accelerator\"] = \"gpu\"\n        #\n        standard_args = default_trainer_args()\n        for k in standard_args:\n            if getattr(opt, k) != standard_args[k]:\n                trainer_config[k] = getattr(opt, k)\n\n        ckpt_resume_path = opt.resume_from_checkpoint\n\n        if not \"devices\" in trainer_config and trainer_config[\"accelerator\"] != \"gpu\":\n            del trainer_config[\"accelerator\"]\n            cpu = True\n        else:\n            gpuinfo = trainer_config[\"devices\"]\n            print(f\"Running on GPUs {gpuinfo}\")\n            cpu = False\n        trainer_opt = argparse.Namespace(**trainer_config)\n        lightning_config.trainer = trainer_config\n\n        # model\n        model = instantiate_from_config(config.model)\n\n        # trainer and callbacks\n        trainer_kwargs = dict()\n\n        # default logger configs\n        default_logger_cfgs = {\n            \"wandb\": {\n                \"target\": 
\"pytorch_lightning.loggers.WandbLogger\",\n                \"params\": {\n                    \"name\": nowname,\n                    # \"save_dir\": logdir,\n                    \"offline\": opt.debug,\n                    \"id\": nowname,\n                    \"project\": opt.projectname,\n                    \"log_model\": False,\n                    # \"dir\": logdir,\n                },\n            },\n            \"csv\": {\n                \"target\": \"pytorch_lightning.loggers.CSVLogger\",\n                \"params\": {\n                    \"name\": \"testtube\",  # hack for sbord fanatics\n                    \"save_dir\": logdir,\n                },\n            },\n        }\n        default_logger_cfg = default_logger_cfgs[\"wandb\" if opt.wandb else \"csv\"]\n        if opt.wandb:\n            # TODO change once leaving \"swiffer\" config directory\n            try:\n                group_name = nowname.split(now)[-1].split(\"-\")[1]\n            except:\n                group_name = nowname\n            default_logger_cfg[\"params\"][\"group\"] = group_name\n            init_wandb(\n                os.path.join(os.getcwd(), logdir),\n                opt=opt,\n                group_name=group_name,\n                config=config,\n                name_str=nowname,\n            )\n        if \"logger\" in lightning_config:\n            logger_cfg = lightning_config.logger\n        else:\n            logger_cfg = OmegaConf.create()\n        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)\n        trainer_kwargs[\"logger\"] = instantiate_from_config(logger_cfg)\n\n        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to\n        # specify which metric is used to determine best models\n        default_modelckpt_cfg = {\n            \"target\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n            \"params\": {\n                \"dirpath\": ckptdir,\n                \"filename\": \"{epoch:06}\",\n              
  \"verbose\": True,\n                \"save_last\": True,\n            },\n        }\n        if hasattr(model, \"monitor\"):\n            print(f\"Monitoring {model.monitor} as checkpoint metric.\")\n            default_modelckpt_cfg[\"params\"][\"monitor\"] = model.monitor\n            default_modelckpt_cfg[\"params\"][\"save_top_k\"] = 3\n\n        if \"modelcheckpoint\" in lightning_config:\n            modelckpt_cfg = lightning_config.modelcheckpoint\n        else:\n            modelckpt_cfg = OmegaConf.create()\n        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)\n        print(f\"Merged modelckpt-cfg: \\n{modelckpt_cfg}\")\n\n        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html\n        # default to ddp if not further specified\n        default_strategy_config = {\"target\": \"pytorch_lightning.strategies.DDPStrategy\"}\n\n        if \"strategy\" in lightning_config:\n            strategy_cfg = lightning_config.strategy\n        else:\n            strategy_cfg = OmegaConf.create()\n            default_strategy_config[\"params\"] = {\n                \"find_unused_parameters\": False,\n                # \"static_graph\": True,\n                # \"ddp_comm_hook\": default.fp16_compress_hook  # TODO: experiment with this, also for DDPSharded\n            }\n        strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)\n        print(\n            f\"strategy config: \\n ++++++++++++++ \\n {strategy_cfg} \\n ++++++++++++++ \"\n        )\n        trainer_kwargs[\"strategy\"] = instantiate_from_config(strategy_cfg)\n\n        # add callback which sets up log directory\n        default_callbacks_cfg = {\n            \"setup_callback\": {\n                \"target\": \"main.SetupCallback\",\n                \"params\": {\n                    \"resume\": opt.resume,\n                    \"now\": now,\n                    \"logdir\": logdir,\n                    \"ckptdir\": ckptdir,\n        
            \"cfgdir\": cfgdir,\n                    \"config\": config,\n                    \"lightning_config\": lightning_config,\n                    \"debug\": opt.debug,\n                    \"ckpt_name\": melk_ckpt_name,\n                },\n            },\n            \"image_logger\": {\n                \"target\": \"main.ImageLogger\",\n                \"params\": {\"batch_frequency\": 1000, \"max_images\": 4, \"clamp\": True},\n            },\n            \"learning_rate_logger\": {\n                \"target\": \"pytorch_lightning.callbacks.LearningRateMonitor\",\n                \"params\": {\n                    \"logging_interval\": \"step\",\n                    # \"log_momentum\": True\n                },\n            },\n        }\n        if version.parse(pl.__version__) >= version.parse(\"1.4.0\"):\n            default_callbacks_cfg.update({\"checkpoint_callback\": modelckpt_cfg})\n\n        if \"callbacks\" in lightning_config:\n            callbacks_cfg = lightning_config.callbacks\n        else:\n            callbacks_cfg = OmegaConf.create()\n\n        if \"metrics_over_trainsteps_checkpoint\" in callbacks_cfg:\n            print(\n                \"Caution: Saving checkpoints every n train steps without deleting. 
This might require some free space.\"\n            )\n            default_metrics_over_trainsteps_ckpt_dict = {\n                \"metrics_over_trainsteps_checkpoint\": {\n                    \"target\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n                    \"params\": {\n                        \"dirpath\": os.path.join(ckptdir, \"trainstep_checkpoints\"),\n                        \"filename\": \"{epoch:06}-{step:09}\",\n                        \"verbose\": True,\n                        \"save_top_k\": -1,\n                        \"every_n_train_steps\": 10000,\n                        \"save_weights_only\": True,\n                    },\n                }\n            }\n            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)\n\n        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)\n        if \"ignore_keys_callback\" in callbacks_cfg and ckpt_resume_path is not None:\n            callbacks_cfg.ignore_keys_callback.params[\"ckpt_path\"] = ckpt_resume_path\n        elif \"ignore_keys_callback\" in callbacks_cfg:\n            del callbacks_cfg[\"ignore_keys_callback\"]\n\n        trainer_kwargs[\"callbacks\"] = [\n            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg\n        ]\n        if not \"plugins\" in trainer_kwargs:\n            trainer_kwargs[\"plugins\"] = list()\n\n        # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)\n        trainer_opt = vars(trainer_opt)\n        trainer_kwargs = {\n            key: val for key, val in trainer_kwargs.items() if key not in trainer_opt\n        }\n        trainer = Trainer(**trainer_opt, **trainer_kwargs)\n\n        trainer.logdir = logdir  ###\n\n        # data\n        data = instantiate_from_config(config.data)\n        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html\n        # calling these ourselves 
should not be necessary but it is.\n        # lightning still takes care of proper multiprocessing though\n        data.prepare_data()\n        # data.setup()\n        print(\"#### Data #####\")\n        try:\n            for k in data.datasets:\n                print(\n                    f\"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}\"\n                )\n        except:\n            print(\"datasets not yet initialized.\")\n\n        # configure learning rate\n        if \"batch_size\" in config.data.params:\n            bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate\n        else:\n            bs, base_lr = (\n                config.data.params.train.loader.batch_size,\n                config.model.base_learning_rate,\n            )\n        if not cpu:\n            ngpu = len(lightning_config.trainer.devices.strip(\",\").split(\",\"))\n        else:\n            ngpu = 1\n        if \"accumulate_grad_batches\" in lightning_config.trainer:\n            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches\n        else:\n            accumulate_grad_batches = 1\n        print(f\"accumulate_grad_batches = {accumulate_grad_batches}\")\n        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches\n        if opt.scale_lr:\n            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr\n            print(\n                \"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)\".format(\n                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr\n                )\n            )\n        else:\n            model.learning_rate = base_lr\n            print(\"++++ NOT USING LR SCALING ++++\")\n            print(f\"Setting learning rate to {model.learning_rate:.2e}\")\n\n        # allow checkpointing via USR1\n        def melk(*args, **kwargs):\n            # run all 
checkpoint hooks\n            if trainer.global_rank == 0:\n                print(\"Summoning checkpoint.\")\n                if melk_ckpt_name is None:\n                    ckpt_path = os.path.join(ckptdir, \"last.ckpt\")\n                else:\n                    ckpt_path = os.path.join(ckptdir, melk_ckpt_name)\n                trainer.save_checkpoint(ckpt_path)\n\n        def divein(*args, **kwargs):\n            if trainer.global_rank == 0:\n                import pudb\n\n                pudb.set_trace()\n\n        import signal\n\n        signal.signal(signal.SIGUSR1, melk)\n        signal.signal(signal.SIGUSR2, divein)\n\n        # run\n        if opt.train:\n            try:\n                trainer.fit(model, data, ckpt_path=ckpt_resume_path)\n            except Exception:\n                if not opt.debug:\n                    melk()\n                raise\n        if not opt.no_test and not trainer.interrupted:\n            trainer.test(model, data)\n    except RuntimeError as err:\n        if MULTINODE_HACKS:\n            import datetime\n            import os\n            import socket\n\n            import requests\n\n            device = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"?\")\n            hostname = socket.gethostname()\n            ts = datetime.datetime.utcnow().strftime(\"%Y-%m-%d %H:%M:%S\")\n            resp = requests.get(\"http://169.254.169.254/latest/meta-data/instance-id\")\n            print(\n                f\"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}\",\n                flush=True,\n            )\n        raise err\n    except Exception:\n        if opt.debug and trainer.global_rank == 0:\n            try:\n                import pudb as debugger\n            except ImportError:\n                import pdb as debugger\n            debugger.post_mortem()\n        raise\n    finally:\n        # move newly created debug project to debug_runs\n        if opt.debug and not 
opt.resume and trainer.global_rank == 0:\n            dst, name = os.path.split(logdir)\n            dst = os.path.join(dst, \"debug_runs\", name)\n            os.makedirs(os.path.split(dst)[0], exist_ok=True)\n            os.rename(logdir, dst)\n\n        if opt.wandb:\n            wandb.finish()\n        # if trainer.global_rank == 0:\n        #    print(trainer.profiler.summary())\n"
  },
  {
    "path": "model_licenses/LICENSE-SDXL-Turbo",
    "content": "STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT        \nDated: November 28, 2023\n\n\nBy using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement.\n\n\n\"Agreement\" means this Stable Non-Commercial Research Community License Agreement.\n\n\n“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.\n\n\n\"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.\n\n\n“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.\n\n\n\"Licensee\" or \"you\" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.\n\n\n“Model(s)\" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.\n\n\n“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works. \n\n\n\"Stability AI\" or \"we\" means Stability AI Ltd. and its affiliates.\n\n\"Software\" means Stability AI’s proprietary software made available under this Agreement. 
\n\n\n“Software Products” means the Models, Software and Documentation, individually or in any combination. \n\n\n\n1.     License Rights and Redistribution. \n\na.  Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to reproduce the Software Products and produce, reproduce, distribute, and create Derivative Works of the Software Products for Non-Commercial Uses only, respectively. \n\nb.  You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.    \n\nc.  
If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a \"Notice\" text file distributed as a part of such copies: \"This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.\n\n2.     Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS  AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN \"AS IS\" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. \n\n3.     Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. \n\n4.     
Intellectual Property.\n\na.  No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works. \n\nb.  Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works \n\nc.  If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement. \n\n5.      Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.\n"
  },
  {
    "path": "model_licenses/LICENSE-SDXL0.9",
    "content": "SDXL 0.9 RESEARCH LICENSE AGREEMENT\nCopyright (c) Stability AI Ltd.\nThis License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).\nBy clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.\n1. LICENSE GRANT\n\na. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. 
The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.\n\nb. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.\n\nc. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.\n\n\n2. RESTRICTIONS\n\nYou will not, and will not permit, assist or cause any third party to:\n\na. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;\n\nb. alter or remove copyright and other proprietary notices which appear on or in the Software Products;\n\nc. 
utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or\n\nd. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.\n\ne. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.\n\n\n3. ATTRIBUTION\n\nTogether with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”\n\n\n4. DISCLAIMERS\n\nTHE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. 
STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.\n\n\n5. LIMITATION OF LIABILITY\n\nTO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.\n\n\n6. 
INDEMNIFICATION\n\nYou will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.\n\n\n7. TERMINATION; SURVIVAL\n\na. This License will automatically terminate upon any breach by you of the terms of this License.\n\nb. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.\n\nc. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).\n\n\n8. 
THIRD PARTY MATERIALS\n\nThe Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.\n\n\n9. TRADEMARKS\n\nLicensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.\n\n\n10. APPLICABLE LAW; DISPUTE RESOLUTION\n\nThis License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.\n\n\n11. MISCELLANEOUS\n\nIf any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. 
This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI."
  },
  {
    "path": "model_licenses/LICENSE-SDXL1.0",
    "content": "Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023\n\nSection I: PREAMBLE Multimodal generative models are being widely adopted and used, and\nhave the potential to transform the way artists, among other individuals, conceive and\nbenefit from AI or ML technologies as a tool for content creation. Notwithstanding the\ncurrent and potential benefits that these artifacts can bring to society at large, there\nare also concerns about potential misuses of them, either due to their technical\nlimitations or ethical considerations. In short, this license strives for both the open\nand responsible downstream use of the accompanying model. When it comes to the open\ncharacter, we took inspiration from open source permissive licenses regarding the grant\nof IP rights. Referring to the downstream responsible use, we added use-based\nrestrictions not permitting the use of the model in very specific scenarios, in order\nfor the licensor to be able to enforce the license in case potential misuses of the\nModel may occur. At the same time, we strive to promote open and responsible research on\ngenerative models for art and content generation. Even though downstream derivative\nversions of the model could be released under different licensing terms, the latter will\nalways have to include - at minimum - the same use-based restrictions as the ones in the\noriginal license (this license). We believe in the intersection between open and\nresponsible AI development; thus, this agreement aims to strike a balance between both\nin order to enable responsible open-science in the field of AI. This CreativeML Open\nRAIL++-M License governs the use of the model (and its derivatives) and is informed by\nthe model card associated with the model. NOW THEREFORE, You and Licensor agree as\nfollows: Definitions \"License\" means the terms and conditions for use, reproduction, and\nDistribution as defined in this document. 
\"Data\" means a collection of information\nand/or content extracted from the dataset used with the Model, including to train,\npretrain, or otherwise evaluate the Model. The Data is not licensed under this License.\n\"Output\" means the results of operating a Model as embodied in informational content\nresulting therefrom. \"Model\" means any accompanying machine-learning based assemblies\n(including checkpoints), consisting of learnt weights, parameters (including optimizer\nstates), corresponding to the model architecture as embodied in the Complementary\nMaterial, that have been trained or tuned, in whole or in part on the Data, using the\nComplementary Material. \"Derivatives of the Model\" means all modifications to the Model,\nworks based on the Model, or any other model which is created or initialized by transfer\nof patterns of the weights, parameters, activations or output of the Model, to the other\nmodel, in order to cause the other model to perform similarly to the Model, including -\nbut not limited to - distillation methods entailing the use of intermediate data\nrepresentations or methods based on the generation of synthetic data by the Model for\ntraining the other model. \"Complementary Material\" means the accompanying source code\nand scripts used to define, run, load, benchmark or evaluate the Model, and used to\nprepare data for training or evaluation, if any. This includes any accompanying\ndocumentation, tutorials, examples, etc, if any. \"Distribution\" means any transmission,\nreproduction, publication or other sharing of the Model or Derivatives of the Model to a\nthird party, including providing the Model as a hosted service made available by\nelectronic or other remote means - e.g. API-based or web access. \"Licensor\" means the\ncopyright owner or entity authorized by the copyright owner that is granting the\nLicense, including the persons or entities that may have rights in the Model and/or\ndistributing the Model. 
\"You\" (or \"Your\") means an individual or Legal Entity exercising\npermissions granted by this License and/or making use of the Model for whichever purpose\nand in any field of use, including usage of the Model in an end-use application - e.g.\nchatbot, translator, image generator. \"Third Parties\" means individuals or legal\nentities that are not under common control with Licensor or You. \"Contribution\" means\nany work of authorship, including the original version of the Model and any\nmodifications or additions to that Model or Derivatives of the Model thereof, that is\nintentionally submitted to Licensor for inclusion in the Model by the copyright owner or\nby an individual or Legal Entity authorized to submit on behalf of the copyright owner.\nFor the purposes of this definition, \"submitted\" means any form of electronic, verbal,\nor written communication sent to the Licensor or its representatives, including but not\nlimited to communication on electronic mailing lists, source code control systems, and\nissue tracking systems that are managed by, or on behalf of, the Licensor for the\npurpose of discussing and improving the Model, but excluding communication that is\nconspicuously marked or otherwise designated in writing by the copyright owner as \"Not a\nContribution.\" \"Contributor\" means Licensor and any individual or Legal Entity on behalf\nof whom a Contribution has been received by Licensor and subsequently incorporated\nwithin the Model.\n\nSection II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the\nModel, Derivatives of the Model and Complementary Material. The Model and Derivatives of\nthe Model are subject to additional terms as described in\n\nSection III. Grant of Copyright License. 
Subject to the terms and conditions of this\nLicense, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,\nno-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly\ndisplay, publicly perform, sublicense, and distribute the Complementary Material, the\nModel, and Derivatives of the Model. Grant of Patent License. Subject to the terms and\nconditions of this License and where and as applicable, each Contributor hereby grants\nto You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n(except as stated in this paragraph) patent license to make, have made, use, offer to\nsell, sell, import, and otherwise transfer the Model and the Complementary Material,\nwhere such license applies only to those patent claims licensable by such Contributor\nthat are necessarily infringed by their Contribution(s) alone or by combination of their\nContribution(s) with the Model to which such Contribution(s) was submitted. If You\ninstitute patent litigation against any entity (including a cross-claim or counterclaim\nin a lawsuit) alleging that the Model and/or Complementary Material or a Contribution\nincorporated within the Model and/or Complementary Material constitutes direct or\ncontributory patent infringement, then any patent licenses granted to You under this\nLicense for the Model and/or Work shall terminate as of the date such litigation is\nasserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION\nDistribution and Redistribution. You may host for Third Party remote access purposes\n(e.g. software-as-a-service), reproduce and distribute copies of the Model or\nDerivatives of the Model thereof in any medium, with or without modifications, provided\nthat You meet the following conditions: Use-based restrictions as referenced in\nparagraph 5 MUST be included as an enforceable provision by You in any type of legal\nagreement (e.g. 
a license) governing the use and/or distribution of the Model or\nDerivatives of the Model, and You shall give notice to subsequent users You Distribute\nto, that the Model or Derivatives of the Model are subject to paragraph 5. This\nprovision does not apply to the use of Complementary Material. You must give any Third\nParty recipients of the Model or Derivatives of the Model a copy of this License; You\nmust cause any modified files to carry prominent notices stating that You changed the\nfiles; You must retain all copyright, patent, trademark, and attribution notices\nexcluding those notices that do not pertain to any part of the Model, Derivatives of the\nModel. You may add Your own copyright statement to Your modifications and may provide\nadditional or different license terms and conditions - respecting paragraph 4.a. - for\nuse, reproduction, or Distribution of Your modifications, or for any such Derivatives of\nthe Model as a whole, provided Your use, reproduction, and Distribution of the Model\notherwise complies with the conditions stated in this License. Use-based restrictions.\nThe restrictions set forth in Attachment A are considered Use-based restrictions.\nTherefore You cannot use the Model and the Derivatives of the Model for the specified\nrestricted uses. You may use the Model subject to this License, including only for\nlawful purposes and in accordance with the License. Use may include creating any content\nwith, finetuning, updating, running, training, evaluating and/or reparametrizing the\nModel. You shall require all of Your users who use the Model or a Derivative of the\nModel to comply with the terms of this paragraph (paragraph 5). The Output You Generate.\nExcept as set forth herein, Licensor claims no rights in the Output You generate using\nthe Model. You are accountable for the Output you generate and its subsequent uses. 
No\nuse of the output can contravene any provision as stated in the License.\n\nSection IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent\npermitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage\nof the Model in violation of this License. Trademarks and related. Nothing in this\nLicense permits You to make use of Licensors’ trademarks, trade names, logos or to\notherwise suggest endorsement or misrepresent the relationship between the parties; and\nany rights not expressly granted herein are reserved by the Licensors. Disclaimer of\nWarranty. Unless required by applicable law or agreed to in writing, Licensor provides\nthe Model and the Complementary Material (and each Contributor provides its\nContributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either\nexpress or implied, including, without limitation, any warranties or conditions of\nTITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are\nsolely responsible for determining the appropriateness of using or redistributing the\nModel, Derivatives of the Model, and the Complementary Material and assume any risks\nassociated with Your exercise of permissions under this License. Limitation of\nLiability. 
In no event and under no legal theory, whether in tort (including\nnegligence), contract, or otherwise, unless required by applicable law (such as\ndeliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be\nliable to You for damages, including any direct, indirect, special, incidental, or\nconsequential damages of any character arising as a result of this License or out of the\nuse or inability to use the Model and the Complementary Material (including but not\nlimited to damages for loss of goodwill, work stoppage, computer failure or malfunction,\nor any and all other commercial damages or losses), even if such Contributor has been\nadvised of the possibility of such damages. Accepting Warranty or Additional Liability.\nWhile redistributing the Model, Derivatives of the Model and the Complementary Material\nthereof, You may choose to offer, and charge a fee for, acceptance of support, warranty,\nindemnity, or other liability obligations and/or rights consistent with this License.\nHowever, in accepting such obligations, You may act only on Your own behalf and on Your\nsole responsibility, not on behalf of any other Contributor, and only if You agree to\nindemnify, defend, and hold each Contributor harmless for any liability incurred by, or\nclaims asserted against, such Contributor by reason of your accepting any such warranty\nor additional liability. 
If any provision of this License is held to be invalid, illegal\nor unenforceable, the remaining provisions shall be unaffected thereby and remain valid\nas if such provision had not been set forth herein.\n\nEND OF TERMS AND CONDITIONS\n\nAttachment A Use Restrictions\nYou agree not to use the Model or Derivatives of the Model:\nIn any way that violates any applicable national, federal, state, local or\ninternational law or regulation; For the purpose of exploiting, harming or attempting to\nexploit or harm minors in any way; To generate or disseminate verifiably false\ninformation and/or content with the purpose of harming others; To generate or\ndisseminate personal identifiable information that can be used to harm an individual; To\ndefame, disparage or otherwise harass others; For fully automated decision making that\nadversely impacts an individual’s legal rights or otherwise creates or modifies a\nbinding, enforceable obligation; For any use intended to or which has the effect of\ndiscriminating against or harming individuals or groups based on online or offline\nsocial behavior or known or predicted personal or personality characteristics; To\nexploit any of the vulnerabilities of a specific group of persons based on their age,\nsocial, physical or mental characteristics, in order to materially distort the behavior\nof a person pertaining to that group in a manner that causes or is likely to cause that\nperson or another person physical or psychological harm; For any use intended to or\nwhich has the effect of discriminating against individuals or groups based on legally\nprotected characteristics or categories; To provide medical advice and medical results\ninterpretation; To generate or disseminate information for the purpose to be used for\nadministration of justice, law enforcement, immigration or asylum processes, such as\npredicting an individual will commit fraud/crime commitment (e.g. 
by text profiling,\ndrawing causal relationships between assertions made in documents, indiscriminate and\narbitrarily-targeted use).\n"
  },
  {
    "path": "model_licenses/LICENSE-SV3D",
    "content": "STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT\nDated: March 18, 2024\n\n\"Agreement\" means this Stable Non-Commercial Research Community License Agreement.\n\n“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.\n\n\"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.\n\n“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.\n\n\"Licensee\" or \"you\" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.\n\n“Model(s)\" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.\n\n“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.\n\n\"Stability AI\" or \"we\" means Stability AI Ltd and its affiliates.\n\n\n\"Software\" means Stability AI’s proprietary software made available under this Agreement.\n\n“Software Products” means the Models, Software and Documentation, individually or in any combination.\n\n\n\n1. \tLicense Rights and Redistribution.\na.  
\tSubject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.\nb.   You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.\nc.\tIf you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a \"Notice\" text file distributed as a part of such copies: \"This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. 
All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.\n2.\tDisclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS  AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN \"AS IS\" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.\n3.\tLimitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.\n4.   \tIntellectual Property.\na. 
\tNo trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.\nb.\tSubject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works\nc. \tIf you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.\n5. \tTerm and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.\n\n6.\tGoverning Law. 
This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law\nprinciples.\n\n"
  },
  {
    "path": "model_licenses/LICENSE-SVD",
    "content": "STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT\t\nDated: November 21, 2023\n\n“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.\n\n\"Agreement\" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.\n\"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.\n“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.\n\n\"Licensee\" or \"you\" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.\n\n\"Stability AI\" or \"we\" means Stability AI Ltd. \n\n\"Software\" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.\n\n“Software Products” means Software and Documentation. \n\nBy using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.\n\n\n\nLicense Rights and Redistribution. 
\nSubject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.     \nb.\tIf you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a \"Notice\" text file distributed as a part of such copies: \"Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.\n2. \t  Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS  AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN \"AS IS\" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. \n3.   
Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. \n3.   Intellectual Property.\na. \tNo trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products. \nSubject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works. \nIf you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement. \n4.   Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. 
Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement. \n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[project]\nname = \"sgm\"\ndynamic = [\"version\"]\ndescription = \"Stability Generative Models\"\nreadme = \"README.md\"\nlicense-files = { paths = [\"LICENSE-CODE\"] }\nrequires-python = \">=3.8\"\n\n[project.urls]\nHomepage = \"https://github.com/Stability-AI/generative-models\"\n\n[tool.hatch.version]\npath = \"sgm/__init__.py\"\n\n[tool.hatch.build]\n# This needs to be explicitly set so the configuration files\n# grafted into the `sgm` directory get included in the wheel's\n# RECORD file.\ninclude = [\n    \"sgm\",\n]\n# The force-include configurations below make Hatch copy\n# the configs/ directory (containing the various YAML configuration\n# files the models require) into the source distribution and the wheel.\n\n[tool.hatch.build.targets.sdist.force-include]\n\"./configs\" = \"sgm/configs\"\n\n[tool.hatch.build.targets.wheel.force-include]\n\"./configs\" = \"sgm/configs\"\n\n[tool.hatch.envs.ci]\nskip-install = false\n\ndependencies = [\n    \"pytest\"\n]\n\n[tool.hatch.envs.ci.scripts]\ntest-inference = [\n    \"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118\",\n    \"pip install -r requirements/pt2.txt\",\n    \"pytest -v tests/inference/test_inference.py {args}\",\n]\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\nmarkers = \n  inference: mark as inference test (deselect with '-m \"not inference\"')"
  },
  {
    "path": "requirements/pt2.txt",
    "content": "black==23.7.0\nchardet==5.1.0\nclip @ git+https://github.com/openai/CLIP.git\neinops>=0.6.1\nfairscale>=0.4.13\nfire>=0.5.0\nfsspec>=2023.6.0\nimageio[ffmpeg]\nimageio[pyav]\ninvisible-watermark>=0.2.0\nkornia==0.6.9\nmatplotlib>=3.7.2\nnatsort>=8.4.0\nninja>=1.11.1\nnumpy>=1.24.4\nomegaconf>=2.3.0\nonnxruntime\nopen-clip-torch>=2.20.0\nopencv-python==4.6.0.66\npandas>=2.0.3\npillow>=9.5.0\npudb>=2022.1.3\npytorch-lightning==2.0.1\npyyaml>=6.0.1\nrembg\nscipy>=1.10.1\nstreamlit>=0.73.1\ntensorboardx==2.6\ntimm>=0.9.2\ntokenizers==0.12.1\ntorch>=2.0.1\ntorchaudio>=2.0.2\ntorchdata==0.6.1\ntorchmetrics>=1.0.1\ntorchvision>=0.15.2\ntqdm>=4.65.0\ntransformers==4.19.1\ntriton==2.0.0\nurllib3<1.27,>=1.25.4\nwandb>=0.15.6\nwebdataset>=0.2.33\nwheel>=0.41.0\nxformers>=0.0.20\ngradio\nstreamlit-keyup==0.2.0\n"
  },
  {
    "path": "scripts/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/demo/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/demo/detect.py",
    "content": "import argparse\n\nimport cv2\nimport numpy as np\n\ntry:\n    from imwatermark import WatermarkDecoder\nexcept ImportError as e:\n    try:\n        # Some other dependencies (e.g. torch) may be missing, so import the\n        # maxDct module file directly without loading unnecessary libraries.\n        import importlib.util\n        import sys\n\n        spec = importlib.util.find_spec(\"imwatermark.maxDct\")\n        assert spec is not None\n        maxDct = importlib.util.module_from_spec(spec)\n        sys.modules[\"maxDct\"] = maxDct\n        spec.loader.exec_module(maxDct)\n\n        class WatermarkDecoder(object):\n            \"\"\"A minimal version of\n            https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py\n            to only reconstruct bits using dwtDct\"\"\"\n\n            def __init__(self, wm_type=\"bytes\", length=0):\n                assert wm_type == \"bits\", \"Only bits defined in minimal import\"\n                self._wmType = wm_type\n                self._wmLen = length\n\n            def reconstruct(self, bits):\n                if len(bits) != self._wmLen:\n                    raise RuntimeError(\"number of bits does not match watermark length\")\n\n                return bits\n\n            def decode(self, cv2Image, method=\"dwtDct\", **configs):\n                (r, c, channels) = cv2Image.shape\n                if r * c < 256 * 256:\n                    raise RuntimeError(\"image too small, should be larger than 256x256\")\n\n                assert method == \"dwtDct\"\n                embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)\n                bits = embed.decode(cv2Image)\n                return self.reconstruct(bits)\n\n    except Exception:\n        raise e\n\n\n# A fixed 48-bit message that was chosen at random\n# WATERMARK_MESSAGE = 0xB3EC907BB19E\nWATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110\n# bin(x)[2:] gives bits 
of x as str, use int to convert them to 0/1\nWATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]\nMATCH_VALUES = [\n    [27, \"No watermark detected\"],\n    [33, \"Partial watermark match. Cannot determine with certainty.\"],\n    [\n        35,\n        (\n            \"Likely watermarked. In our test 0.02% of real images were \"\n            'falsely detected as \"Likely watermarked\"'\n        ),\n    ],\n    [\n        49,\n        (\n            \"Very likely watermarked. In our test no real images were \"\n            'falsely detected as \"Very likely watermarked\"'\n        ),\n    ],\n]\n\n\nclass GetWatermarkMatch:\n    def __init__(self, watermark):\n        self.watermark = watermark\n        self.num_bits = len(self.watermark)\n        self.decoder = WatermarkDecoder(\"bits\", self.num_bits)\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Counts the bits that match the predefined watermark for one\n        or multiple images. Images should be in cv2 format, e.g. 
h x w x c BGR.\n\n        Args:\n            x: ([B], h, w, c) in range [0, 255]\n\n        Returns:\n            number of matched bits ([B],)\n        \"\"\"\n        squeeze = len(x.shape) == 3\n        if squeeze:\n            x = x[None, ...]\n\n        bs = x.shape[0]\n        detected = np.empty((bs, self.num_bits), dtype=bool)\n        for k in range(bs):\n            detected[k] = self.decoder.decode(x[k], \"dwtDct\")\n        result = np.sum(detected == self.watermark, axis=-1)\n        if squeeze:\n            return result[0]\n        else:\n            return result\n\n\nget_watermark_match = GetWatermarkMatch(WATERMARK_BITS)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"filename\",\n        nargs=\"+\",\n        type=str,\n        help=\"Image files to check for watermarks\",\n    )\n    opts = parser.parse_args()\n\n    print(\n        \"\"\"\n        This script tries to detect watermarked images. Please be aware of\n        the following:\n        - As the watermark is supposed to be invisible, there is the risk that\n          watermarked images may not be detected.\n        - To maximize the chance of detection, make sure that the image has the same\n          dimensions as when the watermark was applied (most likely 1024x1024\n          or 512x512).\n        - Specific image manipulation may drastically decrease the chance that\n          watermarks can be detected.\n        - An image may also share the characteristics of the watermark purely by\n          coincidence.\n        - The watermark script is public: anybody may watermark any image and\n          could therefore claim that it was generated.\n        - All numbers below are based on a test using 10,000 images without any\n          modifications after applying the watermark.\n        \"\"\"\n    )\n\n    for fn in opts.filename:\n        image = cv2.imread(fn)\n        if image is None:\n            print(f\"Couldn't 
read {fn}. Skipping\")\n            continue\n\n        num_bits = get_watermark_match(image)\n        k = 0\n        while num_bits > MATCH_VALUES[k][0]:\n            k += 1\n        print(\n            f\"{fn}: {MATCH_VALUES[k][1]}\",\n            f\"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\\n\",\n            sep=\"\\n\\t\",\n        )\n"
  },
  {
    "path": "scripts/demo/discretization.py",
    "content": "import torch\n\nfrom sgm.modules.diffusionmodules.discretizer import Discretization\n\n\nclass Img2ImgDiscretizationWrapper:\n    \"\"\"\n    Wraps a discretizer and prunes the sigmas.\n    params:\n        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)\n    \"\"\"\n\n    def __init__(self, discretization: Discretization, strength: float = 1.0):\n        self.discretization = discretization\n        self.strength = strength\n        assert 0.0 <= self.strength <= 1.0\n\n    def __call__(self, *args, **kwargs):\n        # the discretizer returns sigmas in decreasing order (largest first)\n        sigmas = self.discretization(*args, **kwargs)\n        print(\"sigmas after discretization, before pruning img2img: \", sigmas)\n        # flip to ascending order and keep the lowest `strength` fraction (at least 1)\n        sigmas = torch.flip(sigmas, (0,))\n        prune_index = max(int(self.strength * len(sigmas)), 1)\n        sigmas = sigmas[:prune_index]\n        print(\"prune index:\", prune_index)\n        sigmas = torch.flip(sigmas, (0,))\n        print(\"sigmas after pruning: \", sigmas)\n        return sigmas\n\n\nclass Txt2NoisyDiscretizationWrapper:\n    \"\"\"\n    Wraps a discretizer and prunes the sigmas.\n    params:\n        strength: float between 0.0 and 1.0. 
0.0 means full sampling (all sigmas are returned)\n    \"\"\"\n\n    def __init__(\n        self, discretization: Discretization, strength: float = 0.0, original_steps=None\n    ):\n        self.discretization = discretization\n        self.strength = strength\n        self.original_steps = original_steps\n        assert 0.0 <= self.strength <= 1.0\n\n    def __call__(self, *args, **kwargs):\n        # the discretizer returns sigmas in decreasing order (largest first)\n        sigmas = self.discretization(*args, **kwargs)\n        print(\"sigmas after discretization, before pruning txt2noisy: \", sigmas)\n        sigmas = torch.flip(sigmas, (0,))\n        if self.original_steps is None:\n            steps = len(sigmas)\n        else:\n            steps = self.original_steps + 1\n        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)\n        sigmas = sigmas[prune_index:]\n        print(\"prune index:\", prune_index)\n        sigmas = torch.flip(sigmas, (0,))\n        print(\"sigmas after pruning: \", sigmas)\n        return sigmas\n"
  },
  {
    "path": "scripts/demo/gradio_app.py",
    "content": "# Adding this at the very top of app.py to make 'generative-models' directory discoverable\nimport os\nimport sys\n\nsys.path.append(os.path.join(os.path.dirname(__file__), \"generative-models\"))\n\nimport math\nimport random\nimport uuid\nfrom glob import glob\nfrom pathlib import Path\nfrom typing import Optional\n\nimport cv2\nimport gradio as gr\nimport numpy as np\nimport torch\nfrom einops import rearrange, repeat\nfrom fire import Fire\nfrom huggingface_hub import hf_hub_download\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfrom torchvision.transforms import ToTensor\n\nfrom scripts.sampling.simple_video_sample import (\n    get_batch,\n    get_unique_embedder_keys_from_conditioner,\n    load_model,\n)\nfrom scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\nfrom sgm.inference.helpers import embed_watermark\nfrom sgm.util import default, instantiate_from_config\n\n# To download all svd models\n# hf_hub_download(repo_id=\"stabilityai/stable-video-diffusion-img2vid-xt\", filename=\"svd_xt.safetensors\", local_dir=\"checkpoints\")\n# hf_hub_download(repo_id=\"stabilityai/stable-video-diffusion-img2vid\", filename=\"svd.safetensors\", local_dir=\"checkpoints\")\n# hf_hub_download(repo_id=\"stabilityai/stable-video-diffusion-img2vid-xt-1-1\", filename=\"svd_xt_1_1.safetensors\", local_dir=\"checkpoints\")\n\n\n# Define the repo, local directory and filename\nrepo_id = \"stabilityai/stable-video-diffusion-img2vid-xt-1-1\"  # replace with \"stabilityai/stable-video-diffusion-img2vid-xt\" or \"stabilityai/stable-video-diffusion-img2vid\" for other models\nfilename = \"svd_xt_1_1.safetensors\"  # replace with \"svd_xt.safetensors\" or \"svd.safetensors\" for other models\nlocal_dir = \"checkpoints\"\nlocal_file_path = os.path.join(local_dir, filename)\n\n# Check if the file already exists\nif not os.path.exists(local_file_path):\n    # If the file doesn't exist, download it\n    
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)\n    print(\"File downloaded.\")\nelse:\n    print(\"File already exists. No need to download.\")\n\n\nversion = \"svd_xt_1_1\"  # replace with 'svd_xt' or 'svd' for other models\ndevice = \"cuda\"\nmax_64_bit_int = 2**63 - 1\n\nif version == \"svd_xt_1_1\":\n    num_frames = 25\n    num_steps = 30\n    model_config = \"scripts/sampling/configs/svd_xt_1_1.yaml\"\nelse:\n    raise ValueError(f\"Version {version} does not exist.\")\n\nmodel, filter = load_model(\n    model_config,\n    device,\n    num_frames,\n    num_steps,\n)\n\n\ndef sample(\n    input_path: str = \"assets/test_image.png\",  # Can either be image file or folder with image files\n    seed: Optional[int] = None,\n    randomize_seed: bool = True,\n    motion_bucket_id: int = 127,\n    fps_id: int = 6,\n    version: str = \"svd_xt_1_1\",\n    cond_aug: float = 0.02,\n    decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n    device: str = \"cuda\",\n    output_folder: str = \"outputs\",\n    progress=gr.Progress(track_tqdm=True),\n):\n    \"\"\"\n    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each\n    image file in folder `input_path`. 
If you run out of VRAM, try decreasing `decoding_t`.\n    \"\"\"\n    fps_id = int(fps_id)  # cast float slider values to int\n    if randomize_seed:\n        seed = random.randint(0, max_64_bit_int)\n\n    torch.manual_seed(seed)\n\n    path = Path(input_path)\n    all_img_paths = []\n    if path.is_file():\n        if path.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]:\n            all_img_paths = [input_path]\n        else:\n            raise ValueError(\"Path is not a valid image file.\")\n    elif path.is_dir():\n        all_img_paths = sorted(\n            [\n                f\n                for f in path.iterdir()\n                if f.is_file() and f.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]\n            ]\n        )\n        if len(all_img_paths) == 0:\n            raise ValueError(\"Folder does not contain any images.\")\n    else:\n        raise ValueError(f\"{input_path} is neither a file nor a directory.\")\n\n    for input_img_path in all_img_paths:\n        with Image.open(input_img_path) as image:\n            if image.mode != \"RGB\":\n                image = image.convert(\"RGB\")\n            w, h = image.size\n\n            if h % 64 != 0 or w % 64 != 0:\n                width, height = map(lambda x: x - x % 64, (w, h))\n                image = image.resize((width, height))\n                print(\n                    f\"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!\"\n                )\n\n            image = ToTensor()(image)\n            image = image * 2.0 - 1.0\n\n        image = image.unsqueeze(0).to(device)\n        H, W = image.shape[2:]\n        assert image.shape[1] == 3\n        F = 8\n        C = 4\n        shape = (num_frames, C, H // F, W // F)\n        if (H, W) != (576, 1024):\n            print(\n                \"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as the model was only trained on 576x1024. 
Consider increasing `cond_aug`.\"\n            )\n        if motion_bucket_id > 255:\n            print(\n                \"WARNING: High motion bucket! This may lead to suboptimal performance.\"\n            )\n\n        if fps_id < 5:\n            print(\"WARNING: Small fps value! This may lead to suboptimal performance.\")\n\n        if fps_id > 30:\n            print(\"WARNING: Large fps value! This may lead to suboptimal performance.\")\n\n        value_dict = {}\n        value_dict[\"motion_bucket_id\"] = motion_bucket_id\n        value_dict[\"fps_id\"] = fps_id\n        value_dict[\"cond_aug\"] = cond_aug\n        value_dict[\"cond_frames_without_noise\"] = image\n        value_dict[\"cond_frames\"] = image + cond_aug * torch.randn_like(image)\n\n        with torch.no_grad():\n            with torch.autocast(device):\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    [1, num_frames],\n                    T=num_frames,\n                    device=device,\n                )\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=[\n                        \"cond_frames\",\n                        \"cond_frames_without_noise\",\n                    ],\n                )\n\n                for k in [\"crossattn\", \"concat\"]:\n                    uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n                    uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n                    c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n                    c[k] = rearrange(c[k], \"b t ... 
-> (b t) ...\", t=num_frames)\n\n                randn = torch.randn(shape, device=device)\n\n                additional_model_inputs = {}\n                additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n                    2, num_frames\n                ).to(device)\n                additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n                def denoiser(input, sigma, c):\n                    return model.denoiser(\n                        model.model, input, sigma, c, **additional_model_inputs\n                    )\n\n                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n                model.en_and_decode_n_samples_a_time = decoding_t\n                samples_x = model.decode_first_stage(samples_z)\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n\n                os.makedirs(output_folder, exist_ok=True)\n                base_count = len(glob(os.path.join(output_folder, \"*.mp4\")))\n                video_path = os.path.join(output_folder, f\"{base_count:06d}.mp4\")\n                writer = cv2.VideoWriter(\n                    video_path,\n                    cv2.VideoWriter_fourcc(*\"mp4v\"),\n                    fps_id + 1,\n                    (samples.shape[-1], samples.shape[-2]),\n                )\n\n                samples = embed_watermark(samples)\n                samples = filter(samples)\n                vid = (\n                    (rearrange(samples, \"t c h w -> t h w c\") * 255)\n                    .cpu()\n                    .numpy()\n                    .astype(np.uint8)\n                )\n                for frame in vid:\n                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n                    writer.write(frame)\n                writer.release()\n\n        return video_path, seed\n\n\ndef resize_image(image_path, output_size=(1024, 576)):\n    image = Image.open(image_path)\n    # Calculate aspect ratios\n    
target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size\n    image_aspect = image.width / image.height  # Aspect ratio of the original image\n\n    # Resize so the image covers the target size, then center-crop the excess\n    if image_aspect > target_aspect:\n        # Resize the image to match the target height, maintaining aspect ratio\n        new_height = output_size[1]\n        new_width = int(new_height * image_aspect)\n        resized_image = image.resize((new_width, new_height), Image.LANCZOS)\n        # Calculate coordinates for cropping\n        left = (new_width - output_size[0]) / 2\n        top = 0\n        right = (new_width + output_size[0]) / 2\n        bottom = output_size[1]\n    else:\n        # Resize the image to match the target width, maintaining aspect ratio\n        new_width = output_size[0]\n        new_height = int(new_width / image_aspect)\n        resized_image = image.resize((new_width, new_height), Image.LANCZOS)\n        # Calculate coordinates for cropping\n        left = 0\n        top = (new_height - output_size[1]) / 2\n        right = output_size[0]\n        bottom = (new_height + output_size[1]) / 2\n\n    # Crop the image\n    cropped_image = resized_image.crop((left, top, right, bottom))\n\n    return cropped_image\n\n\nwith gr.Blocks() as demo:\n    gr.Markdown(\n        \"\"\"# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))\n#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate a `4s` video from a single image (`25 frames` at `6 fps`). Generation takes ~60s on an A100. 
[Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).\n  \"\"\"\n    )\n    with gr.Row():\n        with gr.Column():\n            image = gr.Image(label=\"Upload your image\", type=\"filepath\")\n            generate_btn = gr.Button(\"Generate\")\n        video = gr.Video()\n    with gr.Accordion(\"Advanced options\", open=False):\n        seed = gr.Slider(\n            label=\"Seed\",\n            value=42,\n            randomize=True,\n            minimum=0,\n            maximum=max_64_bit_int,\n            step=1,\n        )\n        randomize_seed = gr.Checkbox(label=\"Randomize seed\", value=True)\n        motion_bucket_id = gr.Slider(\n            label=\"Motion bucket id\",\n            info=\"Controls how much motion to add/remove from the image\",\n            value=127,\n            minimum=1,\n            maximum=255,\n        )\n        fps_id = gr.Slider(\n            label=\"Frames per second\",\n            info=\"The length of your video in seconds will be 25/fps\",\n            value=6,\n            minimum=5,\n            maximum=30,\n        )\n\n    image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)\n    generate_btn.click(\n        fn=sample,\n        inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],\n        outputs=[video, seed],\n        api_name=\"video\",\n    )\n\nif __name__ == \"__main__\":\n    demo.queue(max_size=20)\n    demo.launch(share=True)\n"
  },
  {
    "path": "scripts/demo/gradio_app_sv4d.py",
    "content": "# Adding this at the very top of app.py to make 'generative-models' directory discoverable\nimport os\nimport sys\n\nsys.path.append(os.path.join(os.path.dirname(__file__), \"generative-models\"))\n\nfrom glob import glob\nfrom typing import List, Optional, Union\n\nimport gradio as gr\nimport numpy as np\nimport torch\nimport torchvision\nfrom huggingface_hub import hf_hub_download\n\nfrom sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder\nfrom scripts.demo.sv4d_helpers import (\n    decode_latents,\n    load_model,\n    initial_model_load,\n    read_video,\n    run_img2vid,\n    prepare_inputs,\n    do_sample_per_step,\n    sample_sv3d,\n    save_video,\n    preprocess_video,\n)\n\n\n# The tmp path: if /tmp/gradio is not writable, change it to a writable path\n# os.environ[\"GRADIO_TEMP_DIR\"] = \"gradio_tmp\"\n\nversion = \"sv4d\"  # replace with 'sv3d_p' or 'sv3d_u' for other models\n\n# Define the repo, local directory and filename\nrepo_id = \"stabilityai/sv4d\"\nfilename = f\"{version}.safetensors\"  # replace with \"sv3d_u.safetensors\" or \"sv3d_p.safetensors\"\nlocal_dir = \"checkpoints\"\nlocal_ckpt_path = os.path.join(local_dir, filename)\n\n# Check if the file already exists\nif not os.path.exists(local_ckpt_path):\n    # If the file doesn't exist, download it\n    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)\n    print(\"File downloaded. (sv4d)\")\nelse:\n    print(\"File already exists. No need to download. 
(sv4d)\")\n\ndevice = \"cuda\"\nmax_64_bit_int = 2**63 - 1\n\nnum_frames = 21\nnum_steps = 20\nmodel_config = f\"scripts/sampling/configs/{version}.yaml\"\n\n# Set model config\nT = 5  # number of frames per sample\nV = 8  # number of views per sample\nF = 8  # vae factor to downsize image->latent\nC = 4\nH, W = 576, 576\nn_frames = 21  # number of input and output video frames\nn_views = V + 1  # number of output video views (1 input view + 8 novel views)\nn_views_sv3d = 21\nsubsampled_views = np.array(\n    [0, 2, 5, 7, 9, 12, 14, 16, 19]\n)  # subsample (V+1=)9 (uniform) views from 21 SV3D views\n\nversion_dict = {\n    \"T\": T * V,\n    \"H\": H,\n    \"W\": W,\n    \"C\": C,\n    \"f\": F,\n    \"options\": {\n        \"discretization\": 1,\n        \"cfg\": 3,\n        \"sigma_min\": 0.002,\n        \"sigma_max\": 700.0,\n        \"rho\": 7.0,\n        \"guider\": 5,\n        \"num_steps\": num_steps,\n        \"force_uc_zero_embeddings\": [\n            \"cond_frames\",\n            \"cond_frames_without_noise\",\n            \"cond_view\",\n            \"cond_motion\",\n        ],\n        \"additional_guider_kwargs\": {\n            \"additional_cond_keys\": [\"cond_view\", \"cond_motion\"]\n        },\n    },\n}\n\n# Load SV4D model\nmodel, filter = load_model(\n    model_config,\n    device,\n    version_dict[\"T\"],\n    num_steps,\n)\nmodel = initial_model_load(model)\n\n# -----------sv3d config and model loading----------------\n# if version == \"sv3d_u\":\nsv3d_model_config = \"scripts/sampling/configs/sv3d_u.yaml\"\n# elif version == \"sv3d_p\":\n#     sv3d_model_config = \"scripts/sampling/configs/sv3d_p.yaml\"\n# else:\n#     raise ValueError(f\"Version {version} does not exist.\")\n\n# Define the repo, local directory and filename\nrepo_id = \"stabilityai/sv3d\"\nfilename = f\"sv3d_u.safetensors\"  # replace with \"sv3d_u.safetensors\" or \"sv3d_p.safetensors\"\nlocal_dir = \"checkpoints\"\nlocal_ckpt_path = os.path.join(local_dir, 
filename)\n\n# Check if the file already exists\nif not os.path.exists(local_ckpt_path):\n    # If the file doesn't exist, download it\n    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)\n    print(\"File downloaded. (sv3d)\")\nelse:\n    print(\"File already exists. No need to download. (sv3d)\")\n\n# load sv3d model\nsv3d_model, filter = load_model(\n    sv3d_model_config,\n    device,\n    21,\n    num_steps,\n    verbose=False,\n)\nsv3d_model = initial_model_load(sv3d_model)\n# ------------------\n\ndef sample_anchor(\n    input_path: str = \"assets/test_image.png\",  # Can either be image file or folder with image files\n    seed: Optional[int] = None,\n    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.\n    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n    num_steps: int = 20,\n    sv3d_version: str = \"sv3d_u\",  # sv3d_u or sv3d_p\n    fps_id: int = 6,\n    motion_bucket_id: int = 127,\n    cond_aug: float = 1e-5,\n    device: str = \"cuda\",\n    elevations_deg: Optional[Union[float, List[float]]] = 10.0,\n    azimuths_deg: Optional[List[float]] = None,\n    verbose: Optional[bool] = False,\n):\n    \"\"\"\n    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each\n    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.\n    \"\"\"\n    output_folder = os.path.dirname(input_path)\n\n    torch.manual_seed(seed)\n    os.makedirs(output_folder, exist_ok=True)\n\n    # Read input video frames i.e. 
images at view 0\n    print(f\"Reading {input_path}\")\n    images_v0 = read_video(\n        input_path,\n        n_frames=n_frames,\n        device=device,\n    )\n\n    # Get camera viewpoints\n    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):\n        elevations_deg = [elevations_deg] * n_views_sv3d\n    assert (\n        len(elevations_deg) == n_views_sv3d\n    ), f\"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}\"\n    if azimuths_deg is None:\n        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360\n    assert (\n        len(azimuths_deg) == n_views_sv3d\n    ), f\"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}\"\n    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])\n    azimuths_rad = np.array(\n        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]\n    )\n\n    # Sample multi-view images of the first frame using SV3D i.e. 
images at time 0\n    sv3d_model.sampler.num_steps = num_steps\n    print(\"sv3d_model.sampler.num_steps\", sv3d_model.sampler.num_steps)\n    images_t0 = sample_sv3d(\n        images_v0[0],\n        n_views_sv3d,\n        num_steps,\n        sv3d_version,\n        fps_id,\n        motion_bucket_id,\n        cond_aug,\n        decoding_t,\n        device,\n        polars_rad,\n        azimuths_rad,\n        verbose,\n        sv3d_model,\n    )\n    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame\n\n    sv3d_file = os.path.join(output_folder, \"t000.mp4\")\n    save_video(sv3d_file, images_t0.unsqueeze(1))\n    \n    for emb in model.conditioner.embedders:\n        if isinstance(emb, VideoPredictionEmbedderWithEncoder):\n            emb.en_and_decode_n_samples_a_time = encoding_t\n    model.en_and_decode_n_samples_a_time = decoding_t\n    # Initialize image matrix\n    img_matrix = [[None] * n_views for _ in range(n_frames)]\n    for i, v in enumerate(subsampled_views):\n        img_matrix[0][i] = images_t0[v].unsqueeze(0)\n    for t in range(n_frames):\n        img_matrix[t][0] = images_v0[t]\n\n    # Interleaved sampling for anchor frames\n    t0, v0 = 0, 0\n    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]\n    view_indices = np.arange(V) + 1\n    print(f\"Sampling anchor frames {frame_indices}\")\n    image = img_matrix[t0][v0]\n    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)\n    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)\n    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)\n    model.sampler.num_steps = num_steps\n    version_dict[\"options\"][\"num_steps\"] = num_steps\n    samples = run_img2vid(\n        version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t\n    
)\n    samples = samples.view(T, V, 3, H, W)\n    for i, t in enumerate(frame_indices):\n        for j, v in enumerate(view_indices):\n            if img_matrix[t][v] is None:\n                img_matrix[t][v] = samples[i, j][None] * 2 - 1\n\n    # concat video\n    grid_list = []\n    for t in frame_indices:\n        imgs_view = torch.cat(img_matrix[t])\n        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))\n    # save output videos\n    anchor_vis_file = os.path.join(output_folder, \"anchor_vis.mp4\")\n    save_video(anchor_vis_file, grid_list, fps=3)\n    anchor_file = os.path.join(output_folder, \"anchor.mp4\")\n    image_list = samples.view(T * V, 3, H, W).unsqueeze(1) * 2 - 1\n    save_video(anchor_file, image_list)\n\n    return sv3d_file, anchor_vis_file, anchor_file\n\n\ndef sample_all(\n    input_path: str = \"inputs/test_video1.mp4\",  # Can either be video file or folder with image files\n    sv3d_path: str = \"outputs/sv4d/000000_t000.mp4\",\n    anchor_path: str = \"outputs/sv4d/000000_anchor.mp4\",\n    seed: Optional[int] = None,\n    num_steps: int = 20,\n    device: str = \"cuda\",\n    elevations_deg: Optional[Union[float, List[float]]] = 10.0,\n    azimuths_deg: Optional[List[float]] = None,\n):\n    \"\"\"\n    Generate the full novel-view videos (all `n_frames` frames) conditioned on the input video `input_path`,\n    the SV3D multi-view frames in `sv3d_path`, and the anchor frames in `anchor_path`.\n    \"\"\"\n    output_folder = os.path.dirname(input_path)\n    torch.manual_seed(seed)\n    os.makedirs(output_folder, exist_ok=True)\n\n    # Read input video frames i.e. 
images at view 0\n    print(f\"Reading {input_path}\")\n    images_v0 = read_video(\n        input_path,\n        n_frames=n_frames,\n        device=device,\n    )\n\n    images_t0 = read_video(\n        sv3d_path,\n        n_frames=n_views_sv3d,\n        device=device,\n    )\n\n    # Get camera viewpoints\n    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):\n        elevations_deg = [elevations_deg] * n_views_sv3d\n    assert (\n        len(elevations_deg) == n_views_sv3d\n    ), f\"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}\"\n    if azimuths_deg is None:\n        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360\n    assert (\n        len(azimuths_deg) == n_views_sv3d\n    ), f\"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}\"\n    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])\n    azimuths_rad = np.array(\n        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]\n    )\n\n    # Initialize image matrix\n    img_matrix = [[None] * n_views for _ in range(n_frames)]\n    for i, v in enumerate(subsampled_views):\n        img_matrix[0][i] = images_t0[v]\n    for t in range(n_frames):\n        img_matrix[t][0] = images_v0[t]\n\n    # Load the anchor frames produced by interleaved sampling in sample_anchor\n    t0, v0 = 0, 0\n    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]\n    view_indices = np.arange(V) + 1\n\n    anchor_frames = read_video(\n        anchor_path,\n        n_frames=T * V,\n        device=device,\n    )\n    anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W)\n    for i, t in enumerate(frame_indices):\n        for j, v in enumerate(view_indices):\n            if img_matrix[t][v] is None:\n                img_matrix[t][v] = anchor_frames[i, j][None]\n\n    # Dense sampling for the rest\n    print(\"Sampling dense frames:\")\n    for t0 in np.arange(0, n_frames - 
1, T - 1):  # [0, 4, 8, 12, 16]\n        frame_indices = t0 + np.arange(T)\n        print(f\"Sampling dense frames {frame_indices}\")\n        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to(device)\n\n        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)\n\n        # alternate between forward and backward conditioning\n        forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(\n            frame_indices,\n            img_matrix,\n            v0,\n            view_indices,\n            model,\n            version_dict,\n            seed,\n            polars,\n            azims\n        )\n\n        for step in range(num_steps):\n            if step % 2 == 1:\n                c, uc, additional_model_inputs, sampler = forward_inputs\n                frame_indices = forward_frame_indices\n            else:\n                c, uc, additional_model_inputs, sampler = backward_inputs\n                frame_indices = backward_frame_indices\n            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)\n\n            samples = do_sample_per_step(\n                model,\n                sampler,\n                noisy_latents,\n                c,\n                uc,\n                step,\n                additional_model_inputs,\n            )\n            samples = samples.view(T, V, C, H // F, W // F)\n            for i, t in enumerate(frame_indices):\n                for j, v in enumerate(view_indices):\n                    latent_matrix[t, v] = samples[i, j]\n\n        img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)\n\n\n    # concat video\n    grid_list = []\n    for t in range(n_frames):\n        imgs_view = 
torch.cat(img_matrix[t])\n        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))\n    # save output videos\n    vid_file = os.path.join(output_folder, \"sv4d_final.mp4\")\n    save_video(vid_file, grid_list)\n\n    return vid_file, seed\n\n\nwith gr.Blocks() as demo:\n    gr.Markdown(\n        \"\"\"# Demo for SV4D from Stability AI ([model](https://huggingface.co/stabilityai/sv4d), [news](https://stability.ai/news/stable-video-4d))\n#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv4d/blob/main/LICENSE.md)): generate 8 novel view videos from a single-view video (with white background).\n#### It takes ~45s to generate anchor frames and another ~160s to generate full results (21 frames).  \n#### Hints for improving performance:  \n- Use a white background; \n- Place the object in the center of the image; \n- SV4D processes the first 21 frames of the uploaded video. Gradio provides an option to trim the uploaded video if needed.  
\n  \"\"\"\n    )\n    with gr.Row():\n        with gr.Column():\n            input_video = gr.Video(label=\"Upload your video\")\n            generate_btn = gr.Button(\"Step 1: generate 8 novel view videos (5 anchor frames each)\")\n            interpolate_btn = gr.Button(\"Step 2: Extend novel view videos to 21 frames\")\n        with gr.Column():\n            anchor_video = gr.Video(label=\"SV4D outputs (anchor frames)\")\n            sv3d_video = gr.Video(label=\"SV3D outputs\", interactive=False)\n        with gr.Column():\n            sv4d_interpolated_video = gr.Video(label=\"SV4D outputs (21 frames)\")\n\n    with gr.Accordion(\"Advanced options\", open=False):\n        seed = gr.Slider(\n            label=\"Seed\",\n            value=23,\n            # randomize=True,\n            minimum=0,\n            maximum=100,\n            step=1,\n        )\n        encoding_t = gr.Slider(\n            label=\"Encode n frames at a time\",\n            info=\"Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.\",\n            value=8,\n            minimum=1,\n            maximum=40,\n        )\n        decoding_t = gr.Slider(\n            label=\"Decode n frames at a time\",\n            info=\"Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\",\n            value=4,\n            minimum=1,\n            maximum=14,\n        )\n        denoising_steps = gr.Slider(\n            label=\"Number of denoising steps\",\n            info=\"Increasing this may improve quality but takes more time.\",\n            value=20,\n            minimum=10,\n            maximum=50,\n            step=1,\n        )\n        remove_bg = gr.Checkbox(\n            label=\"Remove background\",\n            info=\"We use rembg. 
Alternatively, you can try SAM2 (https://github.com/facebookresearch/segment-anything-2)\",\n        )\n\n    input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False)\n\n    with gr.Row(visible=False):\n        anchor_frames = gr.Video()\n\n    generate_btn.click(\n        fn=sample_anchor,\n        inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],\n        outputs=[sv3d_video, anchor_video, anchor_frames],\n        api_name=\"SV4D output (5 frames)\",\n    )\n\n    interpolate_btn.click(\n        fn=sample_all,\n        inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps],\n        outputs=[sv4d_interpolated_video, seed],\n        api_name=\"SV4D interpolation (21 frames)\",\n    )\n\n    examples = gr.Examples(\n        fn=preprocess_video,\n        examples=[\n            \"./assets/sv4d_videos/test_video1.mp4\",\n            \"./assets/sv4d_videos/test_video2.mp4\",\n            \"./assets/sv4d_videos/green_robot.mp4\",\n            \"./assets/sv4d_videos/dolphin.mp4\",\n            \"./assets/sv4d_videos/lucia_v000.mp4\",\n            \"./assets/sv4d_videos/snowboard_v000.mp4\",\n            \"./assets/sv4d_videos/stroller_v000.mp4\",\n            \"./assets/sv4d_videos/human5.mp4\",\n            \"./assets/sv4d_videos/bunnyman.mp4\",\n            \"./assets/sv4d_videos/hiphop_parrot.mp4\",\n            \"./assets/sv4d_videos/guppie_v0.mp4\",\n            \"./assets/sv4d_videos/wave_hello.mp4\",\n            \"./assets/sv4d_videos/pistol_v0.mp4\",\n            \"./assets/sv4d_videos/human7.mp4\",\n            \"./assets/sv4d_videos/monkey.mp4\",\n            \"./assets/sv4d_videos/train_v0.mp4\",\n        ],\n        inputs=[input_video],\n        run_on_click=True,\n        outputs=[input_video],\n    )\n\nif __name__ == \"__main__\":\n    demo.queue(max_size=20)\n    demo.launch(share=True)\n "
  },
  {
    "path": "scripts/demo/sampling.py",
    "content": "from pytorch_lightning import seed_everything\n\nfrom scripts.demo.streamlit_helpers import *\n\nSAVE_PATH = \"outputs/demo/txt2img/\"\n\nSD_XL_BASE_RATIOS = {\n    \"0.5\": (704, 1408),\n    \"0.52\": (704, 1344),\n    \"0.57\": (768, 1344),\n    \"0.6\": (768, 1280),\n    \"0.68\": (832, 1216),\n    \"0.72\": (832, 1152),\n    \"0.78\": (896, 1152),\n    \"0.82\": (896, 1088),\n    \"0.88\": (960, 1088),\n    \"0.94\": (960, 1024),\n    \"1.0\": (1024, 1024),\n    \"1.07\": (1024, 960),\n    \"1.13\": (1088, 960),\n    \"1.21\": (1088, 896),\n    \"1.29\": (1152, 896),\n    \"1.38\": (1152, 832),\n    \"1.46\": (1216, 832),\n    \"1.67\": (1280, 768),\n    \"1.75\": (1344, 768),\n    \"1.91\": (1344, 704),\n    \"2.0\": (1408, 704),\n    \"2.09\": (1472, 704),\n    \"2.4\": (1536, 640),\n    \"2.5\": (1600, 640),\n    \"2.89\": (1664, 576),\n    \"3.0\": (1728, 576),\n}\n\nVERSION2SPECS = {\n    \"SDXL-base-1.0\": {\n        \"H\": 1024,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"is_legacy\": False,\n        \"config\": \"configs/inference/sd_xl_base.yaml\",\n        \"ckpt\": \"checkpoints/sd_xl_base_1.0.safetensors\",\n    },\n    \"SDXL-base-0.9\": {\n        \"H\": 1024,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"is_legacy\": False,\n        \"config\": \"configs/inference/sd_xl_base.yaml\",\n        \"ckpt\": \"checkpoints/sd_xl_base_0.9.safetensors\",\n    },\n    \"SDXL-refiner-0.9\": {\n        \"H\": 1024,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"is_legacy\": True,\n        \"config\": \"configs/inference/sd_xl_refiner.yaml\",\n        \"ckpt\": \"checkpoints/sd_xl_refiner_0.9.safetensors\",\n    },\n    \"SDXL-refiner-1.0\": {\n        \"H\": 1024,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"is_legacy\": True,\n        \"config\": \"configs/inference/sd_xl_refiner.yaml\",\n        \"ckpt\": 
\"checkpoints/sd_xl_refiner_1.0.safetensors\",\n    },\n}\n\n\ndef load_img(display=True, key=None, device=\"cuda\"):\n    image = get_interactive_image(key=key)\n    if image is None:\n        return None\n    if display:\n        st.image(image)\n    w, h = image.size\n    print(f\"loaded input image of size ({w}, {h})\")\n    width, height = map(\n        lambda x: x - x % 64, (w, h)\n    )  # resize to integer multiple of 64\n    image = image.resize((width, height))\n    image = np.array(image.convert(\"RGB\"))\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n    return image.to(device)\n\n\ndef run_txt2img(\n    state,\n    version,\n    version_dict,\n    is_legacy=False,\n    return_latents=False,\n    filter=None,\n    stage2strength=None,\n):\n    if version.startswith(\"SDXL-base\"):\n        W, H = st.selectbox(\"Resolution:\", list(SD_XL_BASE_RATIOS.values()), 10)\n    else:\n        H = st.number_input(\"H\", value=version_dict[\"H\"], min_value=64, max_value=2048)\n        W = st.number_input(\"W\", value=version_dict[\"W\"], min_value=64, max_value=2048)\n    C = version_dict[\"C\"]\n    F = version_dict[\"f\"]\n\n    init_dict = {\n        \"orig_width\": W,\n        \"orig_height\": H,\n        \"target_width\": W,\n        \"target_height\": H,\n    }\n    value_dict = init_embedder_options(\n        get_unique_embedder_keys_from_conditioner(state[\"model\"].conditioner),\n        init_dict,\n        prompt=prompt,\n        negative_prompt=negative_prompt,\n    )\n    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)\n    num_samples = num_rows * num_cols\n\n    if st.button(\"Sample\"):\n        st.write(f\"**Model I:** {version}\")\n        out = do_sample(\n            state[\"model\"],\n            sampler,\n            value_dict,\n            num_samples,\n            H,\n            W,\n            C,\n            F,\n            
force_uc_zero_embeddings=[\"txt\"] if not is_legacy else [],\n            return_latents=return_latents,\n            filter=filter,\n        )\n        return out\n\n\ndef run_img2img(\n    state,\n    version_dict,\n    is_legacy=False,\n    return_latents=False,\n    filter=None,\n    stage2strength=None,\n):\n    img = load_img()\n    if img is None:\n        return None\n    H, W = img.shape[2], img.shape[3]\n\n    init_dict = {\n        \"orig_width\": W,\n        \"orig_height\": H,\n        \"target_width\": W,\n        \"target_height\": H,\n    }\n    value_dict = init_embedder_options(\n        get_unique_embedder_keys_from_conditioner(state[\"model\"].conditioner),\n        init_dict,\n        prompt=prompt,\n        negative_prompt=negative_prompt,\n    )\n    strength = st.number_input(\n        \"**Img2Img Strength**\", value=0.75, min_value=0.0, max_value=1.0\n    )\n    sampler, num_rows, num_cols = init_sampling(\n        img2img_strength=strength,\n        stage2strength=stage2strength,\n    )\n    num_samples = num_rows * num_cols\n\n    if st.button(\"Sample\"):\n        out = do_img2img(\n            repeat(img, \"1 ... 
-> n ...\", n=num_samples),\n            state[\"model\"],\n            sampler,\n            value_dict,\n            num_samples,\n            force_uc_zero_embeddings=[\"txt\"] if not is_legacy else [],\n            return_latents=return_latents,\n            filter=filter,\n        )\n        return out\n\n\ndef apply_refiner(\n    input,\n    state,\n    sampler,\n    num_samples,\n    prompt,\n    negative_prompt,\n    filter=None,\n    finish_denoising=False,\n):\n    init_dict = {\n        \"orig_width\": input.shape[3] * 8,\n        \"orig_height\": input.shape[2] * 8,\n        \"target_width\": input.shape[3] * 8,\n        \"target_height\": input.shape[2] * 8,\n    }\n\n    value_dict = init_dict\n    value_dict[\"prompt\"] = prompt\n    value_dict[\"negative_prompt\"] = negative_prompt\n\n    value_dict[\"crop_coords_top\"] = 0\n    value_dict[\"crop_coords_left\"] = 0\n\n    value_dict[\"aesthetic_score\"] = 6.0\n    value_dict[\"negative_aesthetic_score\"] = 2.5\n\n    st.warning(f\"refiner input shape: {input.shape}\")\n    samples = do_img2img(\n        input,\n        state[\"model\"],\n        sampler,\n        value_dict,\n        num_samples,\n        skip_encode=True,\n        filter=filter,\n        add_noise=not finish_denoising,\n    )\n\n    return samples\n\n\nif __name__ == \"__main__\":\n    st.title(\"Stable Diffusion\")\n    version = st.selectbox(\"Model Version\", list(VERSION2SPECS.keys()), 0)\n    version_dict = VERSION2SPECS[version]\n    if st.checkbox(\"Load Model\"):\n        mode = st.radio(\"Mode\", (\"txt2img\", \"img2img\"), 0)\n    else:\n        mode = \"skip\"\n    st.write(\"__________________________\")\n\n    set_lowvram_mode(st.checkbox(\"Low vram mode\", True))\n\n    if version.startswith(\"SDXL-base\"):\n        add_pipeline = st.checkbox(\"Load SDXL-refiner?\", False)\n        st.write(\"__________________________\")\n    else:\n        add_pipeline = False\n\n    seed = st.sidebar.number_input(\"seed\", 
value=42, min_value=0, max_value=int(1e9))\n    seed_everything(seed)\n\n    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))\n\n    if mode != \"skip\":\n        state = init_st(version_dict, load_filter=True)\n        if state[\"msg\"]:\n            st.info(state[\"msg\"])\n        model = state[\"model\"]\n\n    is_legacy = version_dict[\"is_legacy\"]\n\n    prompt = st.text_input(\n        \"prompt\",\n        \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n    )\n    if is_legacy:\n        negative_prompt = st.text_input(\"negative prompt\", \"\")\n    else:\n        negative_prompt = \"\"  # which is unused\n\n    stage2strength = None\n    finish_denoising = False\n\n    if add_pipeline:\n        st.write(\"__________________________\")\n        version2 = st.selectbox(\"Refiner:\", [\"SDXL-refiner-1.0\", \"SDXL-refiner-0.9\"])\n        st.warning(\n            f\"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) \"\n        )\n        st.write(\"**Refiner Options:**\")\n\n        version_dict2 = VERSION2SPECS[version2]\n        state2 = init_st(version_dict2, load_filter=False)\n        st.info(state2[\"msg\"])\n\n        stage2strength = st.number_input(\n            \"**Refinement strength**\", value=0.15, min_value=0.0, max_value=1.0\n        )\n\n        sampler2, *_ = init_sampling(\n            key=2,\n            img2img_strength=stage2strength,\n            specify_num_samples=False,\n        )\n        st.write(\"__________________________\")\n        finish_denoising = st.checkbox(\"Finish denoising with refiner.\", True)\n        if not finish_denoising:\n            stage2strength = None\n\n    if mode == \"txt2img\":\n        out = run_txt2img(\n            state,\n            version,\n            version_dict,\n            is_legacy=is_legacy,\n            return_latents=add_pipeline,\n            filter=state.get(\"filter\"),\n            
stage2strength=stage2strength,\n        )\n    elif mode == \"img2img\":\n        out = run_img2img(\n            state,\n            version_dict,\n            is_legacy=is_legacy,\n            return_latents=add_pipeline,\n            filter=state.get(\"filter\"),\n            stage2strength=stage2strength,\n        )\n    elif mode == \"skip\":\n        out = None\n    else:\n        raise ValueError(f\"unknown mode {mode}\")\n    if isinstance(out, (tuple, list)):\n        samples, samples_z = out\n    else:\n        samples = out\n        samples_z = None\n\n    if add_pipeline and samples_z is not None:\n        st.write(\"**Running Refinement Stage**\")\n        samples = apply_refiner(\n            samples_z,\n            state2,\n            sampler2,\n            samples_z.shape[0],\n            prompt=prompt,\n            negative_prompt=negative_prompt if is_legacy else \"\",\n            filter=state.get(\"filter\"),\n            finish_denoising=finish_denoising,\n        )\n\n    if save_locally and samples is not None:\n        perform_save_locally(save_path, samples)\n"
  },
  {
    "path": "scripts/demo/streamlit_helpers.py",
    "content": "import copy\nimport math\nimport os\nfrom glob import glob\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport cv2\nimport imageio\nimport numpy as np\nimport streamlit as st\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms as TT\nfrom einops import rearrange, repeat\nfrom imwatermark import WatermarkEncoder\nfrom omegaconf import ListConfig, OmegaConf\nfrom PIL import Image\nfrom safetensors.torch import load_file as load_safetensors\nfrom scripts.demo.discretization import (\n    Img2ImgDiscretizationWrapper,\n    Txt2NoisyDiscretizationWrapper,\n)\nfrom scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\nfrom sgm.inference.helpers import embed_watermark\nfrom sgm.modules.diffusionmodules.guiders import (\n    LinearPredictionGuider,\n    TrianglePredictionGuider,\n    VanillaCFG,\n)\nfrom sgm.modules.diffusionmodules.sampling import (\n    DPMPP2MSampler,\n    DPMPP2SAncestralSampler,\n    EulerAncestralSampler,\n    EulerEDMSampler,\n    HeunEDMSampler,\n    LinearMultistepSampler,\n)\nfrom sgm.util import append_dims, default, instantiate_from_config\nfrom torch import autocast\nfrom torchvision import transforms\nfrom torchvision.utils import make_grid, save_image\n\n\n@st.cache_resource()\ndef init_st(version_dict, load_ckpt=True, load_filter=True):\n    state = dict()\n    if not \"model\" in state:\n        config = version_dict[\"config\"]\n        ckpt = version_dict[\"ckpt\"]\n\n        config = OmegaConf.load(config)\n        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)\n\n        state[\"msg\"] = msg\n        state[\"model\"] = model\n        state[\"ckpt\"] = ckpt if load_ckpt else None\n        state[\"config\"] = config\n        if load_filter:\n            state[\"filter\"] = DeepFloydDataFiltering(verbose=False)\n    return state\n\n\ndef load_model(model):\n    model.cuda()\n\n\nlowvram_mode = False\n\n\ndef set_lowvram_mode(mode):\n    
global lowvram_mode\n    lowvram_mode = mode\n\n\ndef initial_model_load(model):\n    global lowvram_mode\n    if lowvram_mode:\n        model.model.half()\n    else:\n        model.cuda()\n    return model\n\n\ndef unload_model(model):\n    global lowvram_mode\n    if lowvram_mode:\n        model.cpu()\n        torch.cuda.empty_cache()\n\n\ndef load_model_from_config(config, ckpt=None, verbose=True):\n    model = instantiate_from_config(config.model)\n\n    if ckpt is not None:\n        print(f\"Loading model from {ckpt}\")\n        if ckpt.endswith(\"ckpt\"):\n            pl_sd = torch.load(ckpt, map_location=\"cpu\")\n            if \"global_step\" in pl_sd:\n                global_step = pl_sd[\"global_step\"]\n                st.info(f\"loaded ckpt from global step {global_step}\")\n                print(f\"Global Step: {pl_sd['global_step']}\")\n            sd = pl_sd[\"state_dict\"]\n        elif ckpt.endswith(\"safetensors\"):\n            sd = load_safetensors(ckpt)\n        else:\n            raise NotImplementedError\n\n        msg = None\n\n        m, u = model.load_state_dict(sd, strict=False)\n\n        if len(m) > 0 and verbose:\n            print(\"missing keys:\")\n            print(m)\n        if len(u) > 0 and verbose:\n            print(\"unexpected keys:\")\n            print(u)\n    else:\n        msg = None\n\n    model = initial_model_load(model)\n    model.eval()\n    return model, msg\n\n\ndef get_unique_embedder_keys_from_conditioner(conditioner):\n    return list(set([x.input_key for x in conditioner.embedders]))\n\n\ndef init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):\n    # Hardcoded demo settings; might undergo some changes in the future\n\n    value_dict = {}\n    for key in keys:\n        if key == \"txt\":\n            if prompt is None:\n                prompt = \"A professional photograph of an astronaut riding a pig\"\n            if negative_prompt is None:\n                negative_prompt = \"\"\n\n  
          prompt = st.text_input(\"Prompt\", prompt)\n            negative_prompt = st.text_input(\"Negative prompt\", negative_prompt)\n\n            value_dict[\"prompt\"] = prompt\n            value_dict[\"negative_prompt\"] = negative_prompt\n\n        if key == \"original_size_as_tuple\":\n            orig_width = st.number_input(\n                \"orig_width\",\n                value=init_dict[\"orig_width\"],\n                min_value=16,\n            )\n            orig_height = st.number_input(\n                \"orig_height\",\n                value=init_dict[\"orig_height\"],\n                min_value=16,\n            )\n\n            value_dict[\"orig_width\"] = orig_width\n            value_dict[\"orig_height\"] = orig_height\n\n        if key == \"crop_coords_top_left\":\n            crop_coord_top = st.number_input(\"crop_coords_top\", value=0, min_value=0)\n            crop_coord_left = st.number_input(\"crop_coords_left\", value=0, min_value=0)\n\n            value_dict[\"crop_coords_top\"] = crop_coord_top\n            value_dict[\"crop_coords_left\"] = crop_coord_left\n\n        if key == \"aesthetic_score\":\n            value_dict[\"aesthetic_score\"] = 6.0\n            value_dict[\"negative_aesthetic_score\"] = 2.5\n\n        if key == \"target_size_as_tuple\":\n            value_dict[\"target_width\"] = init_dict[\"target_width\"]\n            value_dict[\"target_height\"] = init_dict[\"target_height\"]\n\n        if key in [\"fps_id\", \"fps\"]:\n            fps = st.number_input(\"fps\", value=6, min_value=1)\n\n            value_dict[\"fps\"] = fps\n            value_dict[\"fps_id\"] = fps - 1\n\n        if key == \"motion_bucket_id\":\n            mb_id = st.number_input(\"motion bucket id\", 0, 511, value=127)\n            value_dict[\"motion_bucket_id\"] = mb_id\n\n        if key == \"pool_image\":\n            st.text(\"Image for pool conditioning\")\n            image = load_img(\n                key=\"pool_image_input\",\n         
       size=224,\n                center_crop=True,\n            )\n            if image is None:\n                st.info(\"Need an image here\")\n                image = torch.zeros(1, 3, 224, 224)\n            value_dict[\"pool_image\"] = image\n\n    return value_dict\n\n\ndef perform_save_locally(save_path, samples):\n    os.makedirs(os.path.join(save_path), exist_ok=True)\n    base_count = len(os.listdir(os.path.join(save_path)))\n    samples = embed_watermark(samples)\n    for sample in samples:\n        sample = 255.0 * rearrange(sample.cpu().numpy(), \"c h w -> h w c\")\n        Image.fromarray(sample.astype(np.uint8)).save(\n            os.path.join(save_path, f\"{base_count:09}.png\")\n        )\n        base_count += 1\n\n\ndef init_save_locally(_dir, init_value: bool = False):\n    save_locally = st.sidebar.checkbox(\"Save images locally\", value=init_value)\n    if save_locally:\n        save_path = st.text_input(\"Save path\", value=os.path.join(_dir, \"samples\"))\n    else:\n        save_path = None\n\n    return save_locally, save_path\n\n\ndef get_guider(options, key):\n    guider = st.sidebar.selectbox(\n        f\"Guider #{key}\",\n        [\n            \"VanillaCFG\",\n            \"IdentityGuider\",\n            \"LinearPredictionGuider\",\n            \"TrianglePredictionGuider\",\n        ],\n        options.get(\"guider\", 0),\n    )\n\n    additional_guider_kwargs = options.pop(\"additional_guider_kwargs\", {})\n\n    if guider == \"IdentityGuider\":\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.IdentityGuider\"\n        }\n    elif guider == \"VanillaCFG\":\n        scale = st.number_input(\n            f\"cfg-scale #{key}\",\n            value=options.get(\"cfg\", 5.0),\n            min_value=0.0,\n        )\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.VanillaCFG\",\n            \"params\": {\n                \"scale\": scale,\n         
       **additional_guider_kwargs,\n            },\n        }\n    elif guider == \"LinearPredictionGuider\":\n        max_scale = st.number_input(\n            f\"max-cfg-scale #{key}\",\n            value=options.get(\"cfg\", 1.5),\n            min_value=1.0,\n        )\n        min_scale = st.sidebar.number_input(\n            f\"min guidance scale\",\n            value=options.get(\"min_cfg\", 1.0),\n            min_value=1.0,\n            max_value=10.0,\n        )\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\",\n            \"params\": {\n                \"max_scale\": max_scale,\n                \"min_scale\": min_scale,\n                \"num_frames\": options[\"num_frames\"],\n                **additional_guider_kwargs,\n            },\n        }\n    elif guider == \"TrianglePredictionGuider\":\n        max_scale = st.number_input(\n            f\"max-cfg-scale #{key}\",\n            value=options.get(\"cfg\", 2.5),\n            min_value=1.0,\n            max_value=10.0,\n        )\n        min_scale = st.sidebar.number_input(\n            f\"min guidance scale\",\n            value=options.get(\"min_cfg\", 1.0),\n            min_value=1.0,\n            max_value=10.0,\n        )\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider\",\n            \"params\": {\n                \"max_scale\": max_scale,\n                \"min_scale\": min_scale,\n                \"num_frames\": options[\"num_frames\"],\n                **additional_guider_kwargs,\n            },\n        }\n    else:\n        raise NotImplementedError\n    return guider_config\n\n\ndef init_sampling(\n    key=1,\n    img2img_strength: Optional[float] = None,\n    specify_num_samples: bool = True,\n    stage2strength: Optional[float] = None,\n    options: Optional[Dict[str, int]] = None,\n):\n    options = {} if options is None else options\n\n    
num_rows, num_cols = 1, 1\n    if specify_num_samples:\n        num_cols = st.number_input(\n            f\"num cols #{key}\", value=num_cols, min_value=1, max_value=10\n        )\n\n    steps = st.number_input(\n        f\"steps #{key}\", value=options.get(\"num_steps\", 50), min_value=1, max_value=1000\n    )\n    sampler = st.sidebar.selectbox(\n        f\"Sampler #{key}\",\n        [\n            \"EulerEDMSampler\",\n            \"HeunEDMSampler\",\n            \"EulerAncestralSampler\",\n            \"DPMPP2SAncestralSampler\",\n            \"DPMPP2MSampler\",\n            \"LinearMultistepSampler\",\n        ],\n        options.get(\"sampler\", 0),\n    )\n    discretization = st.sidebar.selectbox(\n        f\"Discretization #{key}\",\n        [\n            \"LegacyDDPMDiscretization\",\n            \"EDMDiscretization\",\n        ],\n        options.get(\"discretization\", 0),\n    )\n\n    discretization_config = get_discretization(discretization, options=options, key=key)\n\n    guider_config = get_guider(options=options, key=key)\n\n    sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)\n    if img2img_strength is not None:\n        st.warning(\n            f\"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper\"\n        )\n        sampler.discretization = Img2ImgDiscretizationWrapper(\n            sampler.discretization, strength=img2img_strength\n        )\n    if stage2strength is not None:\n        sampler.discretization = Txt2NoisyDiscretizationWrapper(\n            sampler.discretization, strength=stage2strength, original_steps=steps\n        )\n    return sampler, num_rows, num_cols\n\n\ndef get_discretization(discretization, options, key=1):\n    if discretization == \"LegacyDDPMDiscretization\":\n        discretization_config = {\n            \"target\": \"sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\",\n        }\n    elif discretization == \"EDMDiscretization\":\n     
   sigma_min = st.sidebar.number_input(\n            f\"sigma_min #{key}\", value=options.get(\"sigma_min\", 0.03)\n        )  # 0.0292\n        sigma_max = st.sidebar.number_input(\n            f\"sigma_max #{key}\", value=options.get(\"sigma_max\", 14.61)\n        )  # 14.6146\n        rho = st.sidebar.number_input(f\"rho #{key}\", value=options.get(\"rho\", 3.0))\n        discretization_config = {\n            \"target\": \"sgm.modules.diffusionmodules.discretizer.EDMDiscretization\",\n            \"params\": {\n                \"sigma_min\": sigma_min,\n                \"sigma_max\": sigma_max,\n                \"rho\": rho,\n            },\n        }\n\n    return discretization_config\n\n\ndef get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):\n    if sampler_name == \"EulerEDMSampler\" or sampler_name == \"HeunEDMSampler\":\n        s_churn = st.sidebar.number_input(f\"s_churn #{key}\", value=0.0, min_value=0.0)\n        s_tmin = st.sidebar.number_input(f\"s_tmin #{key}\", value=0.0, min_value=0.0)\n        s_tmax = st.sidebar.number_input(f\"s_tmax #{key}\", value=999.0, min_value=0.0)\n        s_noise = st.sidebar.number_input(f\"s_noise #{key}\", value=1.0, min_value=0.0)\n\n        if sampler_name == \"EulerEDMSampler\":\n            sampler = EulerEDMSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                s_churn=s_churn,\n                s_tmin=s_tmin,\n                s_tmax=s_tmax,\n                s_noise=s_noise,\n                verbose=True,\n            )\n        elif sampler_name == \"HeunEDMSampler\":\n            sampler = HeunEDMSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                s_churn=s_churn,\n                s_tmin=s_tmin,\n                s_tmax=s_tmax,\n                
s_noise=s_noise,\n                verbose=True,\n            )\n    elif (\n        sampler_name == \"EulerAncestralSampler\"\n        or sampler_name == \"DPMPP2SAncestralSampler\"\n    ):\n        s_noise = st.sidebar.number_input(\"s_noise\", value=1.0, min_value=0.0)\n        eta = st.sidebar.number_input(\"eta\", value=1.0, min_value=0.0)\n\n        if sampler_name == \"EulerAncestralSampler\":\n            sampler = EulerAncestralSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                eta=eta,\n                s_noise=s_noise,\n                verbose=True,\n            )\n        elif sampler_name == \"DPMPP2SAncestralSampler\":\n            sampler = DPMPP2SAncestralSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                eta=eta,\n                s_noise=s_noise,\n                verbose=True,\n            )\n    elif sampler_name == \"DPMPP2MSampler\":\n        sampler = DPMPP2MSampler(\n            num_steps=steps,\n            discretization_config=discretization_config,\n            guider_config=guider_config,\n            verbose=True,\n        )\n    elif sampler_name == \"LinearMultistepSampler\":\n        order = st.sidebar.number_input(\"order\", value=4, min_value=1)\n        sampler = LinearMultistepSampler(\n            num_steps=steps,\n            discretization_config=discretization_config,\n            guider_config=guider_config,\n            order=order,\n            verbose=True,\n        )\n    else:\n        raise ValueError(f\"unknown sampler {sampler_name}!\")\n\n    return sampler\n\n\ndef get_interactive_image() -> Image.Image:\n    image = st.file_uploader(\"Input\", type=[\"jpg\", \"JPEG\", \"png\"])\n    if image is not None:\n        image = Image.open(image)\n        if not image.mode == \"RGB\":\n          
  image = image.convert(\"RGB\")\n        return image\n\n\ndef load_img(\n    display: bool = True,\n    size: Union[None, int, Tuple[int, int]] = None,\n    center_crop: bool = False,\n    key=None,\n):\n    image = get_interactive_image(key=key)\n    if image is None:\n        return None\n    if display:\n        st.image(image)\n    w, h = image.size\n    print(f\"loaded input image of size ({w}, {h})\")\n\n    transform = []\n    if size is not None:\n        transform.append(transforms.Resize(size))\n    if center_crop:\n        transform.append(transforms.CenterCrop(size))\n    transform.append(transforms.ToTensor())\n    transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))\n\n    transform = transforms.Compose(transform)\n    img = transform(image)[None, ...]\n    st.text(f\"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}\")\n    return img\n\n\ndef get_init_img(batch_size=1, key=None):\n    init_image = load_img(key=key).cuda()\n    init_image = repeat(init_image, \"1 ... 
-> b ...\", b=batch_size)\n    return init_image\n\n\ndef do_sample(\n    model,\n    sampler,\n    value_dict,\n    num_samples,\n    H,\n    W,\n    C,\n    F,\n    force_uc_zero_embeddings: Optional[List] = None,\n    force_cond_zero_embeddings: Optional[List] = None,\n    batch2model_input: List = None,\n    return_latents=False,\n    filter=None,\n    T=None,\n    additional_batch_uc_fields=None,\n    decoding_t=None,\n):\n    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])\n    batch2model_input = default(batch2model_input, [])\n    additional_batch_uc_fields = default(additional_batch_uc_fields, [])\n\n    st.text(\"Sampling\")\n\n    outputs = st.empty()\n    precision_scope = autocast\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                if T is not None:\n                    num_samples = [num_samples, T]\n                else:\n                    num_samples = [num_samples]\n\n                load_model(model.conditioner)\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    num_samples,\n                    T=T,\n                    additional_batch_uc_fields=additional_batch_uc_fields,\n                )\n\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=force_uc_zero_embeddings,\n                    force_cond_zero_embeddings=force_cond_zero_embeddings,\n                )\n                unload_model(model.conditioner)\n\n                for k in c:\n                    if not k == \"crossattn\":\n                        c[k], uc[k] = map(\n                            lambda y: y[k][: math.prod(num_samples)].to(\"cuda\"), (c, uc)\n                        )\n                    if k in [\"crossattn\", 
\"concat\"] and T is not None:\n                        uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=T)\n                        uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=T)\n                        c[k] = repeat(c[k], \"b ... -> b t ...\", t=T)\n                        c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=T)\n\n                additional_model_inputs = {}\n                for k in batch2model_input:\n                    if k == \"image_only_indicator\":\n                        assert T is not None\n\n                        if isinstance(\n                            sampler.guider,\n                            (\n                                VanillaCFG,\n                                LinearPredictionGuider,\n                                TrianglePredictionGuider,\n                            ),\n                        ):\n                            additional_model_inputs[k] = torch.zeros(\n                                num_samples[0] * 2, num_samples[1]\n                            ).to(\"cuda\")\n                        else:\n                            additional_model_inputs[k] = torch.zeros(num_samples).to(\n                                \"cuda\"\n                            )\n                    else:\n                        additional_model_inputs[k] = batch[k]\n\n                shape = (math.prod(num_samples), C, H // F, W // F)\n                randn = torch.randn(shape).to(\"cuda\")\n\n                def denoiser(input, sigma, c):\n                    return model.denoiser(\n                        model.model, input, sigma, c, **additional_model_inputs\n                    )\n\n                load_model(model.denoiser)\n                load_model(model.model)\n                samples_z = sampler(denoiser, randn, cond=c, uc=uc)\n                unload_model(model.model)\n                unload_model(model.denoiser)\n\n                load_model(model.first_stage_model)\n                
model.en_and_decode_n_samples_a_time = (\n                    decoding_t  # Decode n frames at a time\n                )\n                samples_x = model.decode_first_stage(samples_z)\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n                unload_model(model.first_stage_model)\n\n                if filter is not None:\n                    samples = filter(samples)\n\n                if T is None:\n                    grid = torch.stack([samples])\n                    grid = rearrange(grid, \"n b c h w -> (n h) (b w) c\")\n                    outputs.image(grid.cpu().numpy())\n                else:\n                    as_vids = rearrange(samples, \"(b t) c h w -> b t c h w\", t=T)\n                    for i, vid in enumerate(as_vids):\n                        grid = rearrange(make_grid(vid, nrow=4), \"c h w -> h w c\")\n                        st.image(\n                            grid.cpu().numpy(),\n                            f\"Sample #{i} as image\",\n                        )\n\n                if return_latents:\n                    return samples, samples_z\n                return samples\n\n\ndef get_batch(\n    keys,\n    value_dict: dict,\n    N: Union[List, ListConfig],\n    device: str = \"cuda\",\n    T: int = None,\n    additional_batch_uc_fields: List[str] = [],\n):\n    # Hardcoded demo setups; might undergo some changes in the future\n\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"txt\":\n            batch[\"txt\"] = [value_dict[\"prompt\"]] * math.prod(N)\n\n            batch_uc[\"txt\"] = [value_dict[\"negative_prompt\"]] * math.prod(N)\n\n        elif key == \"original_size_as_tuple\":\n            batch[\"original_size_as_tuple\"] = (\n                torch.tensor([value_dict[\"orig_height\"], value_dict[\"orig_width\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n        elif key == \"crop_coords_top_left\":\n            
batch[\"crop_coords_top_left\"] = (\n                torch.tensor(\n                    [value_dict[\"crop_coords_top\"], value_dict[\"crop_coords_left\"]]\n                )\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n        elif key == \"aesthetic_score\":\n            batch[\"aesthetic_score\"] = (\n                torch.tensor([value_dict[\"aesthetic_score\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n            batch_uc[\"aesthetic_score\"] = (\n                torch.tensor([value_dict[\"negative_aesthetic_score\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n\n        elif key == \"target_size_as_tuple\":\n            batch[\"target_size_as_tuple\"] = (\n                torch.tensor([value_dict[\"target_height\"], value_dict[\"target_width\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n        elif key == \"fps\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps\"]]).to(device).repeat(math.prod(N))\n            )\n        elif key == \"fps_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps_id\"]]).to(device).repeat(math.prod(N))\n            )\n        elif key == \"motion_bucket_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"motion_bucket_id\"]])\n                .to(device)\n                .repeat(math.prod(N))\n            )\n        elif key == \"pool_image\":\n            batch[key] = repeat(value_dict[key], \"1 ... 
-> b ...\", b=math.prod(N)).to(\n                device, dtype=torch.half\n            )\n        elif key == \"cond_aug\":\n            batch[key] = repeat(\n                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n                \"1 -> b\",\n                b=math.prod(N),\n            )\n        elif key == \"cond_frames\":\n            batch[key] = repeat(value_dict[\"cond_frames\"], \"1 ... -> b ...\", b=N[0])\n        elif key == \"cond_frames_without_noise\":\n            batch[key] = repeat(\n                value_dict[\"cond_frames_without_noise\"], \"1 ... -> b ...\", b=N[0]\n            )\n        elif key == \"polars_rad\":\n            batch[key] = torch.tensor(value_dict[\"polars_rad\"]).to(device).repeat(N[0])\n        elif key == \"azimuths_rad\":\n            batch[key] = (\n                torch.tensor(value_dict[\"azimuths_rad\"]).to(device).repeat(N[0])\n            )\n        else:\n            batch[key] = value_dict[key]\n\n    if T is not None:\n        batch[\"num_video_frames\"] = T\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n        elif key in additional_batch_uc_fields and key not in batch_uc:\n            batch_uc[key] = copy.copy(batch[key])\n    return batch, batch_uc\n\n\n@torch.no_grad()\ndef do_img2img(\n    img,\n    model,\n    sampler,\n    value_dict,\n    num_samples,\n    force_uc_zero_embeddings: Optional[List] = None,\n    force_cond_zero_embeddings: Optional[List] = None,\n    additional_kwargs={},\n    offset_noise_level: float = 0.0,\n    return_latents=False,\n    skip_encode=False,\n    filter=None,\n    add_noise=True,\n):\n    st.text(\"Sampling\")\n\n    outputs = st.empty()\n    precision_scope = autocast\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                load_model(model.conditioner)\n                batch, batch_uc 
= get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    [num_samples],\n                )\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=force_uc_zero_embeddings,\n                    force_cond_zero_embeddings=force_cond_zero_embeddings,\n                )\n                unload_model(model.conditioner)\n                for k in c:\n                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(\"cuda\"), (c, uc))\n\n                for k in additional_kwargs:\n                    c[k] = uc[k] = additional_kwargs[k]\n                if skip_encode:\n                    z = img\n                else:\n                    load_model(model.first_stage_model)\n                    z = model.encode_first_stage(img)\n                    unload_model(model.first_stage_model)\n\n                noise = torch.randn_like(z)\n\n                sigmas = sampler.discretization(sampler.num_steps).cuda()\n                sigma = sigmas[0]\n\n                st.info(f\"all sigmas: {sigmas}\")\n                st.info(f\"noising sigma: {sigma}\")\n                if offset_noise_level > 0.0:\n                    noise = noise + offset_noise_level * append_dims(\n                        torch.randn(z.shape[0], device=z.device), z.ndim\n                    )\n                if add_noise:\n                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()\n                    noised_z = noised_z / torch.sqrt(\n                        1.0 + sigmas[0] ** 2.0\n                    )  # Note: hardcoded to DDPM-like scaling. 
Need to generalize later.\n                else:\n                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)\n\n                def denoiser(x, sigma, c):\n                    return model.denoiser(model.model, x, sigma, c)\n\n                load_model(model.denoiser)\n                load_model(model.model)\n                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)\n                unload_model(model.model)\n                unload_model(model.denoiser)\n\n                load_model(model.first_stage_model)\n                samples_x = model.decode_first_stage(samples_z)\n                unload_model(model.first_stage_model)\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n\n                if filter is not None:\n                    samples = filter(samples)\n\n                grid = torch.stack([samples])\n                grid = rearrange(grid, \"n b c h w -> (n h) (b w) c\")\n                outputs.image(grid.cpu().numpy())\n                if return_latents:\n                    return samples, samples_z\n                return samples\n\n\ndef get_resizing_factor(\n    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]\n) -> float:\n    r_bound = desired_shape[1] / desired_shape[0]\n    aspect_r = current_shape[1] / current_shape[0]\n    if r_bound >= 1.0:\n        if aspect_r >= r_bound:\n            factor = min(desired_shape) / min(current_shape)\n        else:\n            if aspect_r < 1.0:\n                factor = max(desired_shape) / min(current_shape)\n            else:\n                factor = max(desired_shape) / max(current_shape)\n    else:\n        if aspect_r <= r_bound:\n            factor = min(desired_shape) / min(current_shape)\n        else:\n            if aspect_r > 1:\n                factor = max(desired_shape) / min(current_shape)\n            else:\n                factor = max(desired_shape) / max(current_shape)\n\n    return factor\n\n\ndef get_interactive_image(key=None) -> Image.Image:\n    image = 
st.file_uploader(\"Input\", type=[\"jpg\", \"JPEG\", \"png\"], key=key)\n    if image is not None:\n        image = Image.open(image)\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n        return image\n\n\ndef load_img_for_prediction(\n    W: int, H: int, display=True, key=None, device=\"cuda\"\n) -> torch.Tensor:\n    image = get_interactive_image(key=key)\n    if image is None:\n        return None\n    if display:\n        st.image(image)\n    w, h = image.size\n\n    image = np.array(image).astype(np.float32) / 255\n    if image.shape[-1] == 4:\n        rgb, alpha = image[:, :, :3], image[:, :, 3:]\n        image = rgb * alpha + (1 - alpha)\n\n    image = image.transpose(2, 0, 1)\n    image = torch.from_numpy(image).to(dtype=torch.float32)\n    image = image.unsqueeze(0)\n\n    rfs = get_resizing_factor((H, W), (h, w))\n    resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]\n    top = (resize_size[0] - H) // 2\n    left = (resize_size[1] - W) // 2\n\n    image = torch.nn.functional.interpolate(\n        image, resize_size, mode=\"area\", antialias=False\n    )\n    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)\n\n    if display:\n        numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))\n        pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))\n        st.image(pil_image)\n    return image.to(device) * 2.0 - 1.0\n\n\ndef save_video_as_grid_and_mp4(\n    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5\n):\n    os.makedirs(save_path, exist_ok=True)\n    base_count = len(glob(os.path.join(save_path, \"*.mp4\")))\n\n    video_batch = rearrange(video_batch, \"(b t) c h w -> b t c h w\", t=T)\n    video_batch = embed_watermark(video_batch)\n    for vid in video_batch:\n        save_image(vid, fp=os.path.join(save_path, f\"{base_count:06d}.png\"), nrow=4)\n\n        video_path = os.path.join(save_path, f\"{base_count:06d}.mp4\")\n        vid = (\n            
(rearrange(vid, \"t c h w -> t h w c\") * 255).cpu().numpy().astype(np.uint8)\n        )\n        imageio.mimwrite(video_path, vid, fps=fps)\n\n        video_path_h264 = video_path[:-4] + \"_h264.mp4\"\n        os.system(f\"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'\")\n        with open(video_path_h264, \"rb\") as f:\n            video_bytes = f.read()\n        os.remove(video_path_h264)\n        st.video(video_bytes)\n\n        base_count += 1\n"
  },
  {
    "path": "scripts/demo/sv3d_helpers.py",
    "content": "import os\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n\ndef generate_dynamic_cycle_xy_values(\n    length=21,\n    init_elev=0,\n    num_components=84,\n    frequency_range=(1, 5),\n    amplitude_range=(0.5, 10),\n    step_range=(0, 2),\n):\n    # Y values generation\n    y_sequence = np.ones(length) * init_elev\n    for _ in range(num_components):\n        # Choose a frequency that will complete whole cycles in the sequence\n        frequency = np.random.randint(*frequency_range) * (2 * np.pi / length)\n        amplitude = np.random.uniform(*amplitude_range)\n        phase_shift = np.random.choice([0, np.pi])  # np.random.uniform(0, 2 * np.pi)\n        angles = (\n            np.linspace(0, frequency * length, length, endpoint=False) + phase_shift\n        )\n        y_sequence += np.sin(angles) * amplitude\n    # X values generation\n    # Generate length - 1 steps since the last step is back to start\n    steps = np.random.uniform(*step_range, length - 1)\n    total_step_sum = np.sum(steps)\n    # Calculate the scale factor to scale total steps to just under 360\n    scale_factor = (\n        360 - ((360 / length) * np.random.uniform(*step_range))\n    ) / total_step_sum\n    # Apply the scale factor and generate the sequence of X values\n    x_values = np.cumsum(steps * scale_factor)\n    # Ensure the sequence starts at 0 and add the final step to complete the loop\n    x_values = np.insert(x_values, 0, 0)\n    return x_values, y_sequence\n\n\ndef smooth_data(data, window_size):\n    # Extend data at both ends by wrapping around to create a continuous loop\n    pad_size = window_size\n    padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size]))\n\n    # Apply smoothing\n    kernel = np.ones(window_size) / window_size\n    smoothed_data = np.convolve(padded_data, kernel, mode=\"same\")\n\n    # Extract the smoothed data corresponding to the original sequence\n    # Adjust the indices to account for the larger 
padding\n    start_index = pad_size\n    end_index = -pad_size if pad_size != 0 else None\n    smoothed_original_data = smoothed_data[start_index:end_index]\n    return smoothed_original_data\n\n\n# Function to generate and process the data\ndef gen_dynamic_loop(length=21, elev_deg=0):\n    while True:\n        # Generate the combined X and Y values using the new function\n        azim_values, elev_values = generate_dynamic_cycle_xy_values(\n            length=84, init_elev=elev_deg\n        )\n        # Smooth the Y values directly\n        smoothed_elev_values = smooth_data(elev_values, 5)\n        max_magnitude = np.max(np.abs(smoothed_elev_values))\n        if max_magnitude < 90:\n            break\n    subsample = 84 // length\n    azim_rad = np.deg2rad(azim_values[::subsample])\n    elev_rad = np.deg2rad(smoothed_elev_values[::subsample])\n    # Make cond frame the last one\n    return np.roll(azim_rad, -1), np.roll(elev_rad, -1)\n\n\ndef plot_3D(azim, polar, save_path, dynamic=True):\n    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n    elev = np.deg2rad(90) - polar\n    fig = plt.figure(figsize=(5, 5))\n    ax = fig.add_subplot(projection=\"3d\")\n    cm = plt.get_cmap(\"Greys\")\n    col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)]\n    cm = plt.get_cmap(\"cool\")\n    col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))]\n    xs = np.cos(elev) * np.cos(azim)\n    ys = np.cos(elev) * np.sin(azim)\n    zs = np.sin(elev)\n    ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0])\n    xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1])\n    for i in range(len(xs) - 1):\n        if dynamic:\n            ax.quiver(\n                xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i]\n            )\n        else:\n            ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i])\n        ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])\n    
ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors=\"none\", edgecolors=\"k\")\n    ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors=\"none\", edgecolors=\"k\")\n    ax.view_init(elev=30, azim=-20, roll=0)\n    plt.savefig(save_path, bbox_inches=\"tight\")\n    plt.clf()\n    plt.close()\n"
  },
  {
    "path": "scripts/demo/sv4d_helpers.py",
    "content": "import math\nimport os\nfrom glob import glob\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport cv2\nimport imageio\nimport numpy as np\nimport torch\nimport torchvision.transforms as TT\nfrom einops import rearrange, repeat\nfrom omegaconf import ListConfig, OmegaConf\nfrom PIL import Image, ImageSequence\nfrom rembg import remove\nfrom scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\nfrom sgm.modules.autoencoding.temporal_ae import VideoDecoder\nfrom sgm.modules.diffusionmodules.guiders import (\n    LinearPredictionGuider,\n    SpatiotemporalPredictionGuider,\n    TrapezoidPredictionGuider,\n    TrianglePredictionGuider,\n    VanillaCFG,\n)\nfrom sgm.modules.diffusionmodules.sampling import (\n    DPMPP2MSampler,\n    DPMPP2SAncestralSampler,\n    EulerAncestralSampler,\n    EulerEDMSampler,\n    HeunEDMSampler,\n    LinearMultistepSampler,\n)\nfrom sgm.util import default, instantiate_from_config\nfrom torch import autocast\nfrom torchvision.transforms import ToTensor\n\n\ndef load_module_gpu(model):\n    model.cuda()\n\n\ndef unload_module_gpu(model):\n    model.cpu()\n    torch.cuda.empty_cache()\n\n\ndef initial_model_load(model):\n    model.model.half()\n    return model\n\n\ndef get_resizing_factor(\n    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]\n) -> float:\n    r_bound = desired_shape[1] / desired_shape[0]\n    aspect_r = current_shape[1] / current_shape[0]\n    if r_bound >= 1.0:\n        if aspect_r >= r_bound:\n            factor = min(desired_shape) / min(current_shape)\n        else:\n            if aspect_r < 1.0:\n                factor = max(desired_shape) / min(current_shape)\n            else:\n                factor = max(desired_shape) / max(current_shape)\n    else:\n        if aspect_r <= r_bound:\n            factor = min(desired_shape) / min(current_shape)\n        else:\n            if aspect_r > 1:\n                factor = 
max(desired_shape) / min(current_shape)\n            else:\n                factor = max(desired_shape) / max(current_shape)\n    return factor\n\n\ndef read_gif(input_path, n_frames):\n    frames = []\n    video = Image.open(input_path)\n    for img in ImageSequence.Iterator(video):\n        frames.append(img.convert(\"RGBA\"))\n        if len(frames) == n_frames:\n            break\n    return frames\n\n\ndef read_mp4(input_path, n_frames):\n    frames = []\n    vidcap = cv2.VideoCapture(input_path)\n    success, image = vidcap.read()\n    while success:\n        frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))\n        success, image = vidcap.read()\n        if len(frames) == n_frames:\n            break\n    return frames\n\n\ndef save_img(file_name, img):\n    output_dir = os.path.dirname(file_name)\n    os.makedirs(output_dir, exist_ok=True)\n    imageio.imwrite(\n        file_name,\n        (((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8),\n    )\n\n\ndef save_video(file_name, imgs, fps=10):\n    output_dir = os.path.dirname(file_name)\n    os.makedirs(output_dir, exist_ok=True)\n    img_grid = [\n        (((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8)\n        for img in imgs\n    ]\n    if file_name.endswith(\".gif\"):\n        imageio.mimwrite(file_name, img_grid, fps=fps, loop=0)\n    else:\n        imageio.mimwrite(file_name, img_grid, fps=fps)\n\n\ndef read_video(\n    input_path: str,\n    n_frames: int,\n    device: str = \"cuda\",\n):\n    path = Path(input_path)\n    is_video_file = False\n    all_img_paths = []\n    if path.is_file():\n        if any([input_path.endswith(x) for x in [\".gif\", \".mp4\"]]):\n            is_video_file = True\n        else:\n            raise ValueError(\"Path is not a valid video file.\")\n    elif path.is_dir():\n        all_img_paths = sorted(\n            [\n                f\n                for f in path.iterdir()\n                if 
f.is_file() and f.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]\n            ]\n        )[:n_frames]\n    elif \"*\" in input_path:\n        all_img_paths = sorted(glob(input_path))[:n_frames]\n    else:\n        raise ValueError\n\n    if is_video_file and input_path.endswith(\".gif\"):\n        images = read_gif(input_path, n_frames)[:n_frames]\n    elif is_video_file and input_path.endswith(\".mp4\"):\n        images = read_mp4(input_path, n_frames)[:n_frames]\n    else:\n        print(f\"Loading {len(all_img_paths)} video frames...\")\n        images = [Image.open(img_path) for img_path in all_img_paths]\n\n    if len(images) < n_frames:\n        images = (images + images[::-1])[:n_frames]\n    if len(images) != n_frames:\n        raise ValueError(f\"Input video contains fewer than {n_frames} frames.\")\n\n    images_v0 = []\n\n    for image in images:\n        image = ToTensor()(image).unsqueeze(0).to(device)\n        images_v0.append(image * 2.0 - 1.0)\n    return images_v0\n\n\ndef preprocess_video(\n    input_path,\n    remove_bg=False,\n    n_frames=21,\n    W=576,\n    H=576,\n    output_folder=None,\n    image_frame_ratio=0.917,\n    base_count=0,\n):\n    print(f\"preprocess {input_path}\")\n    if output_folder is None:\n        output_folder = os.path.dirname(input_path)\n    path = Path(input_path)\n    is_video_file = False\n    all_img_paths = []\n    if path.is_file():\n        if any([input_path.endswith(x) for x in [\".gif\", \".mp4\"]]):\n            is_video_file = True\n        else:\n            raise ValueError(\"Path is not a valid video file.\")\n    elif path.is_dir():\n        all_img_paths = sorted(\n            [\n                f\n                for f in path.iterdir()\n                if f.is_file() and f.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]\n            ]\n        )[:n_frames]\n    elif \"*\" in input_path:\n        all_img_paths = sorted(glob(input_path))[:n_frames]\n    else:\n        raise ValueError\n\n    if 
is_video_file and input_path.endswith(\".gif\"):\n        images = read_gif(input_path, n_frames)[:n_frames]\n    elif is_video_file and input_path.endswith(\".mp4\"):\n        images = read_mp4(input_path, n_frames)[:n_frames]\n    else:\n        print(f\"Loading {len(all_img_paths)} video frames...\")\n        images = [Image.open(img_path) for img_path in all_img_paths]\n\n    if len(images) != n_frames:\n        raise ValueError(\n            f\"Input video contains {len(images)} frames, fewer than {n_frames} frames.\"\n        )\n\n    # Remove background\n    for i, image in enumerate(images):\n        if remove_bg:\n            if image.mode == \"RGBA\":\n                pass\n            else:\n                # image.thumbnail([W, H], Image.Resampling.LANCZOS)\n                image = remove(image.convert(\"RGBA\"), alpha_matting=True)\n            images[i] = image\n\n    # Crop video frames, assume the object is already in the center of the image\n    white_thresh = 250\n    images_v0 = []\n    box_coord = [np.inf, np.inf, 0, 0]\n    for image in images:\n        image_arr = np.array(image)\n        in_w, in_h = image_arr.shape[:2]\n        original_center = (in_w // 2, in_h // 2)\n        if image.mode == \"RGBA\":\n            ret, mask = cv2.threshold(\n                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY\n            )\n        else:\n            # assume the input image has white background\n            ret, mask = cv2.threshold(\n                (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255,\n                0,\n                255,\n                cv2.THRESH_BINARY,\n            )\n\n        x, y, w, h = cv2.boundingRect(mask)\n        box_coord[0] = min(box_coord[0], x)\n        box_coord[1] = min(box_coord[1], y)\n        box_coord[2] = max(box_coord[2], x + w)\n        box_coord[3] = max(box_coord[3], y + h)\n    box_square = max(\n        original_center[0] - box_coord[0], original_center[1] - 
box_coord[1]\n    )\n    box_square = max(box_square, box_coord[2] - original_center[0])\n    box_square = max(box_square, box_coord[3] - original_center[1])\n    x, y = max(0, original_center[0] - box_square), max(\n        0, original_center[1] - box_square\n    )\n    w, h = min(image_arr.shape[0], 2 * box_square), min(\n        image_arr.shape[1], 2 * box_square\n    )\n    box_size = box_square * 2\n\n    for image in images:\n        if image.mode == \"RGB\":\n            image = image.convert(\"RGBA\")\n        image_arr = np.array(image)\n        side_len = (\n            int(box_size / image_frame_ratio) if image_frame_ratio is not None else in_w\n        )\n        padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)\n        center = side_len // 2\n        box_size_w = min(w, box_size)\n        box_size_h = min(h, box_size)\n        padded_image[\n            center - box_size_w // 2 : center - box_size_w // 2 + box_size_w,\n            center - box_size_h // 2 : center - box_size_h // 2 + box_size_h,\n        ] = image_arr[x : x + w, y : y + h]\n\n        rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)\n        # rgba = image.resize((W, H), Image.LANCZOS)\n        rgba_arr = np.array(rgba) / 255.0\n        rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])\n        image = (rgb * 255).astype(np.uint8)\n\n        images_v0.append(image)\n\n    processed_file = os.path.join(output_folder, f\"{base_count:06d}_process_input.mp4\")\n    imageio.mimwrite(processed_file, images_v0, fps=10)\n    return processed_file\n\n\ndef sample_sv3d(\n    image,\n    num_frames: Optional[int] = None,  # 21 for SV3D\n    num_steps: Optional[int] = None,\n    version: str = \"sv3d_u\",\n    fps_id: int = 6,\n    motion_bucket_id: int = 127,\n    cond_aug: float = 0.02,\n    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. 
Reduce if necessary.\n    device: str = \"cuda\",\n    polar_rad: Optional[Union[float, List[float]]] = None,\n    azim_rad: Optional[List[float]] = None,\n    verbose: Optional[bool] = False,\n    sv3d_model=None,\n):\n    \"\"\"\n    Generate one SV3D orbit video of `num_frames` views conditioned on the image\n    tensor `image` of shape (1, C, H, W). For `version=\"sv3d_p\"`, `polar_rad` and\n    `azim_rad` give the polar and azimuth angle (in radians) of each view. Returns\n    the decoded frames clamped to [-1, 1]. If you run out of VRAM, try decreasing\n    `decoding_t`.\n    \"\"\"\n\n    if sv3d_model is None:\n        if version == \"sv3d_u\":\n            model_config = \"scripts/sampling/configs/sv3d_u.yaml\"\n        elif version == \"sv3d_p\":\n            model_config = \"scripts/sampling/configs/sv3d_p.yaml\"\n        else:\n            raise ValueError(f\"Version {version} does not exist.\")\n\n        model, filter = load_model(\n            model_config,\n            device,\n            num_frames,\n            num_steps,\n            verbose,\n        )\n    else:\n        model = sv3d_model\n\n    load_module_gpu(model)\n\n    H, W = image.shape[2:]\n    F = 8\n    C = 4\n    shape = (num_frames, C, H // F, W // F)\n\n    value_dict = {}\n    value_dict[\"cond_frames_without_noise\"] = image\n    value_dict[\"motion_bucket_id\"] = motion_bucket_id\n    value_dict[\"fps_id\"] = fps_id\n    value_dict[\"cond_aug\"] = cond_aug\n    value_dict[\"cond_frames\"] = image + cond_aug * torch.randn_like(image)\n    if \"sv3d_p\" in version:\n        value_dict[\"polars_rad\"] = polar_rad\n        value_dict[\"azimuths_rad\"] = azim_rad\n\n    with torch.no_grad():\n        with torch.autocast(device):\n            load_module_gpu(model.conditioner)\n            batch, batch_uc = get_batch_sv3d(\n                get_unique_embedder_keys_from_conditioner(model.conditioner),\n                value_dict,\n                [1, num_frames],\n                T=num_frames,\n                device=device,\n            )\n            c, uc = model.conditioner.get_unconditional_conditioning(\n              
  batch,\n                batch_uc=batch_uc,\n                force_uc_zero_embeddings=[\n                    \"cond_frames\",\n                    \"cond_frames_without_noise\",\n                ],\n            )\n            unload_module_gpu(model.conditioner)\n\n            for k in [\"crossattn\", \"concat\"]:\n                uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n                uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n                c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n                c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=num_frames)\n\n            randn = torch.randn(shape, device=device)\n\n            additional_model_inputs = {}\n            additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n                2, num_frames\n            ).to(device)\n            additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n            def denoiser(input, sigma, c):\n                return model.denoiser(\n                    model.model, input, sigma, c, **additional_model_inputs\n                )\n\n            load_module_gpu(model.model)\n            load_module_gpu(model.denoiser)\n            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n            unload_module_gpu(model.denoiser)\n            unload_module_gpu(model.model)\n\n            load_module_gpu(model.first_stage_model)\n            model.en_and_decode_n_samples_a_time = decoding_t\n            samples_x = model.decode_first_stage(samples_z)\n            unload_module_gpu(model.first_stage_model)\n\n            samples_x[-1:] = value_dict[\"cond_frames_without_noise\"]\n            samples = torch.clamp(samples_x, min=-1.0, max=1.0)\n\n    unload_module_gpu(model)\n    return samples\n\n\ndef decode_latents(\n    model, samples_z, img_matrix, frame_indices, view_indices, timesteps\n):\n    load_module_gpu(model.first_stage_model)\n    for t in frame_indices:\n        
for v in view_indices:\n            if True:  # t != 0 and v != 0:\n                if isinstance(model.first_stage_model.decoder, VideoDecoder):\n                    samples_x = model.decode_first_stage(\n                        samples_z[t, v][None], timesteps=timesteps\n                    )\n                else:\n                    samples_x = model.decode_first_stage(samples_z[t, v][None])\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n                img_matrix[t][v] = samples * 2 - 1\n    unload_module_gpu(model.first_stage_model)\n    return img_matrix\n\n\ndef init_embedder_options_no_st(keys, init_dict, prompt=None, negative_prompt=None):\n    # Hardcoded demo settings; might undergo some changes in the future\n\n    value_dict = {}\n    for key in keys:\n        if key == \"txt\":\n            if prompt is None:\n                prompt = \"A professional photograph of an astronaut riding a pig\"\n            if negative_prompt is None:\n                negative_prompt = \"\"\n\n            value_dict[\"prompt\"] = prompt\n            value_dict[\"negative_prompt\"] = negative_prompt\n\n        if key == \"original_size_as_tuple\":\n            orig_width = init_dict[\"orig_width\"]\n            orig_height = init_dict[\"orig_height\"]\n\n            value_dict[\"orig_width\"] = orig_width\n            value_dict[\"orig_height\"] = orig_height\n\n        if key == \"crop_coords_top_left\":\n            crop_coord_top = 0\n            crop_coord_left = 0\n\n            value_dict[\"crop_coords_top\"] = crop_coord_top\n            value_dict[\"crop_coords_left\"] = crop_coord_left\n\n        if key == \"aesthetic_score\":\n            value_dict[\"aesthetic_score\"] = 6.0\n            value_dict[\"negative_aesthetic_score\"] = 2.5\n\n        if key == \"target_size_as_tuple\":\n            value_dict[\"target_width\"] = init_dict[\"target_width\"]\n            value_dict[\"target_height\"] = 
init_dict[\"target_height\"]\n\n        if key in [\"fps_id\", \"fps\"]:\n            fps = 6\n\n            value_dict[\"fps\"] = fps\n            value_dict[\"fps_id\"] = fps - 1\n\n        if key == \"motion_bucket_id\":\n            mb_id = 127\n            value_dict[\"motion_bucket_id\"] = mb_id\n\n        if key == \"noise_level\":\n            value_dict[\"noise_level\"] = 0\n\n    return value_dict\n\n\ndef get_discretization_no_st(discretization, options, key=1):\n    if discretization == \"LegacyDDPMDiscretization\":\n        discretization_config = {\n            \"target\": \"sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\",\n        }\n    elif discretization == \"EDMDiscretization\":\n        sigma_min = options.get(\"sigma_min\", 0.03)\n        sigma_max = options.get(\"sigma_max\", 14.61)\n        rho = options.get(\"rho\", 3.0)\n        discretization_config = {\n            \"target\": \"sgm.modules.diffusionmodules.discretizer.EDMDiscretization\",\n            \"params\": {\n                \"sigma_min\": sigma_min,\n                \"sigma_max\": sigma_max,\n                \"rho\": rho,\n            },\n        }\n    return discretization_config\n\n\ndef get_guider_no_st(options, key):\n    guider = [\n        \"VanillaCFG\",\n        \"IdentityGuider\",\n        \"LinearPredictionGuider\",\n        \"TrianglePredictionGuider\",\n        \"TrapezoidPredictionGuider\",\n        \"SpatiotemporalPredictionGuider\",\n    ][options.get(\"guider\", 2)]\n\n    additional_guider_kwargs = (\n        options[\"additional_guider_kwargs\"]\n        if \"additional_guider_kwargs\" in options\n        else {}\n    )\n\n    if guider == \"IdentityGuider\":\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.IdentityGuider\"\n        }\n    elif guider == \"VanillaCFG\":\n        scale_schedule = \"Identity\"\n\n        if scale_schedule == \"Identity\":\n            scale = options.get(\"cfg\", 
5.0)\n\n            scale_schedule_config = {\n                \"target\": \"sgm.modules.diffusionmodules.guiders.IdentitySchedule\",\n                \"params\": {\"scale\": scale},\n            }\n\n        elif scale_schedule == \"Oscillating\":\n            small_scale = 4.0\n            large_scale = 16.0\n            sigma_cutoff = 1.0\n\n            scale_schedule_config = {\n                \"target\": \"sgm.modules.diffusionmodules.guiders.OscillatingSchedule\",\n                \"params\": {\n                    \"small_scale\": small_scale,\n                    \"large_scale\": large_scale,\n                    \"sigma_cutoff\": sigma_cutoff,\n                },\n            }\n        else:\n            raise NotImplementedError\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.VanillaCFG\",\n            \"params\": {\n                \"scale_schedule_config\": scale_schedule_config,\n                **additional_guider_kwargs,\n            },\n        }\n    elif guider == \"LinearPredictionGuider\":\n        max_scale = options.get(\"cfg\", 1.5)\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\",\n            \"params\": {\n                \"max_scale\": max_scale,\n                \"num_frames\": options[\"num_frames\"],\n                **additional_guider_kwargs,\n            },\n        }\n    elif guider == \"TrianglePredictionGuider\":\n        max_scale = options.get(\"cfg\", 1.5)\n        period = options.get(\"period\", 1.0)\n        period_fusing = options.get(\"period_fusing\", \"max\")\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider\",\n            \"params\": {\n                \"max_scale\": max_scale,\n                \"num_frames\": options[\"num_frames\"],\n                \"period\": period,\n                \"period_fusing\": period_fusing,\n               
 **additional_guider_kwargs,\n            },\n        }\n    elif guider == \"TrapezoidPredictionGuider\":\n        max_scale = options.get(\"cfg\", 1.5)\n        edge_perc = options.get(\"edge_perc\", 0.1)\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.TrapezoidPredictionGuider\",\n            \"params\": {\n                \"max_scale\": max_scale,\n                \"num_frames\": options[\"num_frames\"],\n                \"edge_perc\": edge_perc,\n                **additional_guider_kwargs,\n            },\n        }\n    elif guider == \"SpatiotemporalPredictionGuider\":\n        max_scale = options.get(\"cfg\", 1.5)\n        min_scale = options.get(\"min_cfg\", 1.0)\n\n        guider_config = {\n            \"target\": \"sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider\",\n            \"params\": {\n                \"max_scale\": max_scale,\n                \"min_scale\": min_scale,\n                \"num_frames\": options[\"num_frames\"],\n                \"num_views\": options[\"num_views\"],\n                **additional_guider_kwargs,\n            },\n        }\n    else:\n        raise NotImplementedError\n    return guider_config\n\n\ndef get_sampler_no_st(sampler_name, steps, discretization_config, guider_config, key=1):\n    if sampler_name == \"EulerEDMSampler\" or sampler_name == \"HeunEDMSampler\":\n        s_churn = 0.0\n        s_tmin = 0.0\n        s_tmax = 999.0\n        s_noise = 1.0\n\n        if sampler_name == \"EulerEDMSampler\":\n            sampler = EulerEDMSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                s_churn=s_churn,\n                s_tmin=s_tmin,\n                s_tmax=s_tmax,\n                s_noise=s_noise,\n                verbose=False,\n            )\n        elif sampler_name == \"HeunEDMSampler\":\n            sampler = HeunEDMSampler(\n     
           num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                s_churn=s_churn,\n                s_tmin=s_tmin,\n                s_tmax=s_tmax,\n                s_noise=s_noise,\n                verbose=False,\n            )\n    elif (\n        sampler_name == \"EulerAncestralSampler\"\n        or sampler_name == \"DPMPP2SAncestralSampler\"\n    ):\n        s_noise = 1.0\n        eta = 1.0\n\n        if sampler_name == \"EulerAncestralSampler\":\n            sampler = EulerAncestralSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                eta=eta,\n                s_noise=s_noise,\n                verbose=False,\n            )\n        elif sampler_name == \"DPMPP2SAncestralSampler\":\n            sampler = DPMPP2SAncestralSampler(\n                num_steps=steps,\n                discretization_config=discretization_config,\n                guider_config=guider_config,\n                eta=eta,\n                s_noise=s_noise,\n                verbose=False,\n            )\n    elif sampler_name == \"DPMPP2MSampler\":\n        sampler = DPMPP2MSampler(\n            num_steps=steps,\n            discretization_config=discretization_config,\n            guider_config=guider_config,\n            verbose=False,\n        )\n    elif sampler_name == \"LinearMultistepSampler\":\n        order = 4\n        sampler = LinearMultistepSampler(\n            num_steps=steps,\n            discretization_config=discretization_config,\n            guider_config=guider_config,\n            order=order,\n            verbose=False,\n        )\n    else:\n        raise ValueError(f\"unknown sampler {sampler_name}!\")\n\n    return sampler\n\n\ndef init_sampling_no_st(\n    key=1,\n    options: Optional[Dict[str, int]] = None,\n):\n    options = {} if options is None else options\n\n    
num_rows, num_cols = 1, 1\n    steps = options.get(\"num_steps\", 50)\n    sampler = [\n        \"EulerEDMSampler\",\n        \"HeunEDMSampler\",\n        \"EulerAncestralSampler\",\n        \"DPMPP2SAncestralSampler\",\n        \"DPMPP2MSampler\",\n        \"LinearMultistepSampler\",\n    ][options.get(\"sampler\", 0)]\n    discretization = [\n        \"LegacyDDPMDiscretization\",\n        \"EDMDiscretization\",\n    ][options.get(\"discretization\", 1)]\n\n    discretization_config = get_discretization_no_st(\n        discretization, options=options, key=key\n    )\n\n    guider_config = get_guider_no_st(options=options, key=key)\n\n    sampler = get_sampler_no_st(\n        sampler, steps, discretization_config, guider_config, key=key\n    )\n    return sampler, num_rows, num_cols\n\n\ndef run_img2vid(\n    version_dict,\n    model,\n    image,\n    seed=23,\n    polar_rad=[10] * 21,\n    azim_rad=np.linspace(0, 360, 21 + 1)[1:],\n    cond_motion=None,\n    cond_view=None,\n    decoding_t=None,\n    cond_mv=True,\n):\n    options = version_dict[\"options\"]\n    H = version_dict[\"H\"]\n    W = version_dict[\"W\"]\n    T = version_dict[\"T\"]\n    C = version_dict[\"C\"]\n    F = version_dict[\"f\"]\n    init_dict = {\n        \"orig_width\": 576,\n        \"orig_height\": 576,\n        \"target_width\": W,\n        \"target_height\": H,\n    }\n    ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner))\n\n    value_dict = init_embedder_options_no_st(\n        ukeys,\n        init_dict,\n        negative_prompt=options.get(\"negative_prompt\", \"\"),\n        prompt=\"A 3D model.\",\n    )\n    if \"fps\" not in ukeys:\n        value_dict[\"fps\"] = 6\n\n    value_dict[\"is_image\"] = 0\n    value_dict[\"is_webvid\"] = 0\n    if cond_mv:\n        value_dict[\"image_only_indicator\"] = 1.0\n    else:\n        value_dict[\"image_only_indicator\"] = 0.0\n\n    cond_aug = 0.00\n    if cond_motion is not None:\n        
value_dict[\"cond_frames_without_noise\"] = cond_motion\n        value_dict[\"cond_frames\"] = (\n            cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1)\n        )\n    else:\n        value_dict[\"cond_frames_without_noise\"] = image\n        value_dict[\"cond_frames\"] = image + cond_aug * torch.randn_like(image)\n    value_dict[\"cond_aug\"] = cond_aug\n    value_dict[\"polar_rad\"] = polar_rad\n    value_dict[\"azimuth_rad\"] = azim_rad\n    value_dict[\"rotated\"] = False\n    value_dict[\"cond_motion\"] = cond_motion\n    value_dict[\"cond_view\"] = cond_view\n\n    # seed_everything(seed)\n\n    options[\"num_frames\"] = T\n    sampler, num_rows, num_cols = init_sampling_no_st(options=options)\n    num_samples = num_rows * num_cols\n\n    samples = do_sample(\n        model,\n        sampler,\n        value_dict,\n        num_samples,\n        H,\n        W,\n        C,\n        F,\n        T=T,\n        batch2model_input=[\"num_video_frames\", \"image_only_indicator\"],\n        force_uc_zero_embeddings=options.get(\"force_uc_zero_embeddings\", None),\n        force_cond_zero_embeddings=options.get(\"force_cond_zero_embeddings\", None),\n        return_latents=False,\n        decoding_t=decoding_t,\n    )\n\n    return samples\n\n\ndef prepare_inputs_forward_backward(\n    img_matrix,\n    view_indices,\n    frame_indices,\n    v0,\n    t0,\n    t1,\n    model,\n    version_dict,\n    seed,\n    polars,\n    azims,\n):\n    # forward sampling\n    forward_frame_indices = frame_indices.copy()\n    image = img_matrix[t0][v0]\n    cond_motion = torch.cat([img_matrix[t][v0] for t in forward_frame_indices], 0)\n    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)\n    forward_inputs = prepare_sampling(\n        version_dict,\n        model,\n        image,\n        seed,\n        polars,\n        azims,\n        cond_motion,\n        cond_view,\n    )\n\n    # backward sampling\n    backward_frame_indices = 
frame_indices[::-1].copy()\n    image = img_matrix[t1][v0]\n    cond_motion = torch.cat([img_matrix[t][v0] for t in backward_frame_indices], 0)\n    cond_view = torch.cat([img_matrix[t1][v] for v in view_indices], 0)\n    backward_inputs = prepare_sampling(\n        version_dict,\n        model,\n        image,\n        seed,\n        polars,\n        azims,\n        cond_motion,\n        cond_view,\n    )\n    return (\n        forward_inputs,\n        forward_frame_indices,\n        backward_inputs,\n        backward_frame_indices,\n    )\n\n\ndef prepare_inputs(\n    frame_indices,\n    img_matrix,\n    v0,\n    view_indices,\n    model,\n    version_dict,\n    seed,\n    polars,\n    azims,\n):\n    load_module_gpu(model.conditioner)\n    # forward sampling\n    forward_frame_indices = frame_indices.copy()\n    t0 = forward_frame_indices[0]\n    image = img_matrix[t0][v0]\n    cond_motion = torch.cat([img_matrix[t][v0] for t in forward_frame_indices], 0)\n    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)\n    forward_inputs = prepare_sampling(\n        version_dict,\n        model,\n        image,\n        seed,\n        polars,\n        azims,\n        cond_motion,\n        cond_view,\n    )\n\n    # backward sampling\n    backward_frame_indices = frame_indices[::-1].copy()\n    t0 = backward_frame_indices[0]\n    image = img_matrix[t0][v0]\n    cond_motion = torch.cat([img_matrix[t][v0] for t in backward_frame_indices], 0)\n    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)\n    backward_inputs = prepare_sampling(\n        version_dict,\n        model,\n        image,\n        seed,\n        polars,\n        azims,\n        cond_motion,\n        cond_view,\n    )\n\n    unload_module_gpu(model.conditioner)\n    return (\n        forward_inputs,\n        forward_frame_indices,\n        backward_inputs,\n        backward_frame_indices,\n    )\n\n\ndef do_sample(\n    model,\n    sampler,\n    value_dict,\n    
num_samples,\n    H,\n    W,\n    C,\n    F,\n    force_uc_zero_embeddings: Optional[List] = None,\n    force_cond_zero_embeddings: Optional[List] = None,\n    batch2model_input: List = None,\n    return_latents=False,\n    filter=None,\n    T=None,\n    additional_batch_uc_fields=None,\n    decoding_t=None,\n):\n    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])\n    batch2model_input = default(batch2model_input, [])\n    additional_batch_uc_fields = default(additional_batch_uc_fields, [])\n\n    precision_scope = autocast\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                if T is not None:\n                    num_samples = [num_samples, T]\n                else:\n                    num_samples = [num_samples]\n\n                load_module_gpu(model.conditioner)\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    num_samples,\n                    T=T,\n                    additional_batch_uc_fields=additional_batch_uc_fields,\n                )\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=force_uc_zero_embeddings,\n                    force_cond_zero_embeddings=force_cond_zero_embeddings,\n                )\n                unload_module_gpu(model.conditioner)\n\n                for k in c:\n                    if not k == \"crossattn\":\n                        c[k], uc[k] = map(\n                            lambda y: y[k][: math.prod(num_samples)].to(\"cuda\"), (c, uc)\n                        )\n\n                if value_dict[\"image_only_indicator\"] == 0:\n                    c[\"cond_view\"] *= 0\n                    uc[\"cond_view\"] *= 0\n\n                additional_model_inputs = {}\n          
      for k in batch2model_input:\n                    if k == \"image_only_indicator\":\n                        assert T is not None\n\n                        if isinstance(\n                            sampler.guider,\n                            (\n                                VanillaCFG,\n                                LinearPredictionGuider,\n                                TrianglePredictionGuider,\n                                TrapezoidPredictionGuider,\n                                SpatiotemporalPredictionGuider,\n                            ),\n                        ):\n                            additional_model_inputs[k] = (\n                                torch.zeros(num_samples[0] * 2, num_samples[1]).to(\n                                    \"cuda\"\n                                )\n                                + value_dict[\"image_only_indicator\"]\n                            )\n                        else:\n                            additional_model_inputs[k] = torch.zeros(num_samples).to(\n                                \"cuda\"\n                            )\n                    else:\n                        additional_model_inputs[k] = batch[k]\n\n                shape = (math.prod(num_samples), C, H // F, W // F)\n                randn = torch.randn(shape).to(\"cuda\")\n\n                def denoiser(input, sigma, c):\n                    return model.denoiser(\n                        model.model, input, sigma, c, **additional_model_inputs\n                    )\n\n                load_module_gpu(model.model)\n                load_module_gpu(model.denoiser)\n                samples_z = sampler(denoiser, randn, cond=c, uc=uc)\n                unload_module_gpu(model.denoiser)\n                unload_module_gpu(model.model)\n\n                load_module_gpu(model.first_stage_model)\n                if isinstance(model.first_stage_model.decoder, VideoDecoder):\n                    samples_x = model.decode_first_stage(\n 
                       samples_z, timesteps=default(decoding_t, T)\n                    )\n                else:\n                    samples_x = model.decode_first_stage(samples_z)\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n                unload_module_gpu(model.first_stage_model)\n\n                if filter is not None:\n                    samples = filter(samples)\n\n                if return_latents:\n                    return samples, samples_z\n\n                return samples\n\n\ndef prepare_sampling_(\n    model,\n    sampler,\n    value_dict,\n    num_samples,\n    force_uc_zero_embeddings: Optional[List] = None,\n    force_cond_zero_embeddings: Optional[List] = None,\n    batch2model_input: List = None,\n    T=None,\n    additional_batch_uc_fields=None,\n):\n    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])\n    batch2model_input = default(batch2model_input, [])\n    additional_batch_uc_fields = default(additional_batch_uc_fields, [])\n\n    precision_scope = autocast\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                if T is not None:\n                    num_samples = [num_samples, T]\n                else:\n                    num_samples = [num_samples]\n                load_module_gpu(model.conditioner)\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    num_samples,\n                    T=T,\n                    additional_batch_uc_fields=additional_batch_uc_fields,\n                )\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=force_uc_zero_embeddings,\n                    force_cond_zero_embeddings=force_cond_zero_embeddings,\n                
)\n                unload_module_gpu(model.conditioner)\n\n                for k in c:\n                    if not k == \"crossattn\":\n                        c[k], uc[k] = map(\n                            lambda y: y[k][: math.prod(num_samples)].to(\"cuda\"), (c, uc)\n                        )\n\n                additional_model_inputs = {}\n                for k in batch2model_input:\n                    if k == \"image_only_indicator\":\n                        assert T is not None\n\n                        if isinstance(\n                            sampler.guider,\n                            (\n                                VanillaCFG,\n                                LinearPredictionGuider,\n                                TrianglePredictionGuider,\n                                TrapezoidPredictionGuider,\n                                SpatiotemporalPredictionGuider,\n                            ),\n                        ):\n                            additional_model_inputs[k] = (\n                                torch.zeros(num_samples[0] * 2, num_samples[1]).to(\n                                    \"cuda\"\n                                )\n                                + value_dict[\"image_only_indicator\"]\n                            )\n                        else:\n                            additional_model_inputs[k] = torch.zeros(num_samples).to(\n                                \"cuda\"\n                            )\n                    else:\n                        additional_model_inputs[k] = batch[k]\n\n    return c, uc, additional_model_inputs\n\n\ndef do_sample_per_step(\n    model, sampler, noisy_latents, c, uc, step, additional_model_inputs\n):\n    precision_scope = autocast\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                noisy_latents_scaled, s_in, sigmas, num_sigmas, _, _ = (\n                    sampler.prepare_sampling_loop(\n                   
     noisy_latents.clone(), c, uc, sampler.num_steps\n                    )\n                )\n\n                if step == 0:\n                    latents = noisy_latents_scaled\n                else:\n                    latents = noisy_latents\n\n                def denoiser(input, sigma, c):\n                    return model.denoiser(\n                        model.model, input, sigma, c, **additional_model_inputs\n                    )\n\n                gamma = (\n                    min(sampler.s_churn / (num_sigmas - 1), 2**0.5 - 1)\n                    if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax\n                    else 0.0\n                )\n\n                load_module_gpu(model.model)\n                load_module_gpu(model.denoiser)\n                samples_z = sampler.sampler_step(\n                    s_in * sigmas[step],\n                    s_in * sigmas[step + 1],\n                    denoiser,\n                    latents,\n                    c,\n                    uc,\n                    gamma,\n                )\n                unload_module_gpu(model.denoiser)\n                unload_module_gpu(model.model)\n    return samples_z\n\n\ndef prepare_sampling(\n    version_dict,\n    model,\n    image,\n    seed=23,\n    polar_rad=[10] * 21,\n    azim_rad=np.linspace(0, 360, 21 + 1)[1:],\n    cond_motion=None,\n    cond_view=None,\n):\n    options = version_dict[\"options\"]\n    H = version_dict[\"H\"]\n    W = version_dict[\"W\"]\n    T = version_dict[\"T\"]\n    C = version_dict[\"C\"]\n    F = version_dict[\"f\"]\n    init_dict = {\n        \"orig_width\": 576,\n        \"orig_height\": 576,\n        \"target_width\": W,\n        \"target_height\": H,\n    }\n    ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner))\n\n    value_dict = init_embedder_options_no_st(\n        ukeys,\n        init_dict,\n        negative_prompt=options.get(\"negative_prompt\", \"\"),\n        prompt=\"A 3D model.\",\n    )\n    if 
\"fps\" not in ukeys:\n        value_dict[\"fps\"] = 6\n\n    value_dict[\"is_image\"] = 0\n    value_dict[\"is_webvid\"] = 0\n    value_dict[\"image_only_indicator\"] = 1.0\n\n    cond_aug = 0.00\n    if cond_motion is not None:\n        value_dict[\"cond_frames_without_noise\"] = cond_motion\n        value_dict[\"cond_frames\"] = (\n            cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1)\n        )\n    else:\n        value_dict[\"cond_frames_without_noise\"] = image\n        value_dict[\"cond_frames\"] = image + cond_aug * torch.randn_like(image)\n    value_dict[\"cond_aug\"] = cond_aug\n    value_dict[\"polar_rad\"] = polar_rad\n    value_dict[\"azimuth_rad\"] = azim_rad\n    value_dict[\"rotated\"] = False\n    value_dict[\"cond_motion\"] = cond_motion\n    value_dict[\"cond_view\"] = cond_view\n\n    options[\"num_frames\"] = T\n    sampler, num_rows, num_cols = init_sampling_no_st(options=options)\n    num_samples = num_rows * num_cols\n\n    c, uc, additional_model_inputs = prepare_sampling_(\n        model,\n        sampler,\n        value_dict,\n        num_samples,\n        force_uc_zero_embeddings=options.get(\"force_uc_zero_embeddings\", None),\n        force_cond_zero_embeddings=options.get(\"force_cond_zero_embeddings\", None),\n        batch2model_input=[\"num_video_frames\", \"image_only_indicator\"],\n        T=T,\n    )\n\n    return c, uc, additional_model_inputs, sampler\n\n\ndef get_unique_embedder_keys_from_conditioner(conditioner):\n    return list(set([x.input_key for x in conditioner.embedders]))\n\n\ndef get_batch_sv3d(keys, value_dict, N, T, device):\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"fps_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"motion_bucket_id\":\n            batch[key] = (\n                
torch.tensor([value_dict[\"motion_bucket_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"cond_aug\":\n            batch[key] = repeat(\n                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n                \"1 -> b\",\n                b=math.prod(N),\n            )\n        elif key == \"cond_frames\" or key == \"cond_frames_without_noise\":\n            batch[key] = repeat(value_dict[key], \"1 ... -> b ...\", b=N[0])\n        elif key == \"polars_rad\" or key == \"azimuths_rad\":\n            batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])\n        else:\n            batch[key] = value_dict[key]\n\n    if T is not None:\n        batch[\"num_video_frames\"] = T\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n    return batch, batch_uc\n\n\ndef get_batch(\n    keys,\n    value_dict: dict,\n    N: Union[List, ListConfig],\n    device: str = \"cuda\",\n    T: int = None,\n    additional_batch_uc_fields: List[str] = [],\n):\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"txt\":\n            batch[\"txt\"] = [value_dict[\"prompt\"]] * math.prod(N)\n            batch_uc[\"txt\"] = [value_dict[\"negative_prompt\"]] * math.prod(N)\n\n        elif key == \"original_size_as_tuple\":\n            batch[\"original_size_as_tuple\"] = (\n                torch.tensor([value_dict[\"orig_height\"], value_dict[\"orig_width\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n        elif key == \"crop_coords_top_left\":\n            batch[\"crop_coords_top_left\"] = (\n                torch.tensor(\n                    [value_dict[\"crop_coords_top\"], value_dict[\"crop_coords_left\"]]\n                )\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n        
elif key == \"aesthetic_score\":\n            batch[\"aesthetic_score\"] = (\n                torch.tensor([value_dict[\"aesthetic_score\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n            batch_uc[\"aesthetic_score\"] = (\n                torch.tensor([value_dict[\"negative_aesthetic_score\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n\n        elif key == \"target_size_as_tuple\":\n            batch[\"target_size_as_tuple\"] = (\n                torch.tensor([value_dict[\"target_height\"], value_dict[\"target_width\"]])\n                .to(device)\n                .repeat(math.prod(N), 1)\n            )\n        elif key == \"fps\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps\"]]).to(device).repeat(math.prod(N))\n            )\n        elif key == \"fps_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps_id\"]]).to(device).repeat(math.prod(N))\n            )\n        elif key == \"motion_bucket_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"motion_bucket_id\"]])\n                .to(device)\n                .repeat(math.prod(N))\n            )\n        elif key == \"pool_image\":\n            batch[key] = repeat(value_dict[key], \"1 ... 
-> b ...\", b=math.prod(N)).to(\n                device, dtype=torch.half\n            )\n        elif key == \"is_image\":\n            batch[key] = (\n                torch.tensor([value_dict[\"is_image\"]])\n                .to(device)\n                .repeat(math.prod(N))\n                .long()\n            )\n        elif key == \"is_webvid\":\n            batch[key] = (\n                torch.tensor([value_dict[\"is_webvid\"]])\n                .to(device)\n                .repeat(math.prod(N))\n                .long()\n            )\n        elif key == \"cond_aug\":\n            batch[key] = repeat(\n                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n                \"1 -> b\",\n                b=math.prod(N),\n            )\n        elif (\n            key == \"cond_frames\"\n            or key == \"cond_frames_without_noise\"\n            or key == \"back_frames\"\n        ):\n            # batch[key] = repeat(value_dict[key], \"1 ... -> b ...\", b=N[0])\n            batch[key] = value_dict[key]\n\n        elif key == \"interpolation_context\":\n            batch[key] = repeat(\n                value_dict[\"interpolation_context\"], \"b ... -> (b n) ...\", n=N[1]\n            )\n\n        elif key == \"start_frame\":\n            assert T is not None\n            batch[key] = repeat(value_dict[key], \"b ... 
-> (b t) ...\", t=T)\n\n        elif key == \"polar_rad\" or key == \"azimuth_rad\":\n            batch[key] = (\n                torch.tensor(value_dict[key]).to(device).repeat(math.prod(N) // T)\n            )\n\n        elif key == \"rotated\":\n            batch[key] = (\n                torch.tensor([value_dict[\"rotated\"]]).to(device).repeat(math.prod(N))\n            )\n\n        else:\n            batch[key] = value_dict[key]\n\n    if T is not None:\n        batch[\"num_video_frames\"] = T\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n        elif key in additional_batch_uc_fields and key not in batch_uc:\n            batch_uc[key] = copy.copy(batch[key])\n    return batch, batch_uc\n\n\ndef load_model(\n    config: str,\n    device: str,\n    num_frames: int,\n    num_steps: int,\n    verbose: bool = False,\n    ckpt_path: str = None,\n):\n    config = OmegaConf.load(config)\n    if device == \"cuda\":\n        config.model.params.conditioner_config.params.emb_models[\n            0\n        ].params.open_clip_embedding_config.params.init_device = device\n\n    config.model.params.sampler_config.params.verbose = verbose\n    config.model.params.sampler_config.params.num_steps = num_steps\n    config.model.params.sampler_config.params.guider_config.params.num_frames = (\n        num_frames\n    )\n    if ckpt_path is not None:\n        config.model.params.ckpt_path = ckpt_path\n    if device == \"cuda\":\n        with torch.device(device):\n            model = instantiate_from_config(config.model).to(device).eval()\n    else:\n        model = instantiate_from_config(config.model).to(device).eval()\n\n    filter = DeepFloydDataFiltering(verbose=False, device=device)\n    return model, filter\n"
  },
  {
    "path": "scripts/demo/turbo.py",
    "content": "from st_keyup import st_keyup\nfrom streamlit_helpers import *\n\nfrom sgm.modules.diffusionmodules.sampling import EulerAncestralSampler\n\nVERSION2SPECS = {\n    \"SDXL-Turbo\": {\n        \"H\": 512,\n        \"W\": 512,\n        \"C\": 4,\n        \"f\": 8,\n        \"is_legacy\": False,\n        \"config\": \"configs/inference/sd_xl_base.yaml\",\n        \"ckpt\": \"checkpoints/sd_xl_turbo_1.0.safetensors\",\n    },\n}\n\n\nclass SubstepSampler(EulerAncestralSampler):\n    def __init__(self, n_sample_steps=1, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.n_sample_steps = n_sample_steps\n        self.steps_subset = [0, 100, 200, 300, 1000]\n\n    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):\n        sigmas = self.discretization(\n            self.num_steps if num_steps is None else num_steps, device=self.device\n        )\n        sigmas = sigmas[\n            self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]\n        ]\n        uc = cond\n        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)\n        num_sigmas = len(sigmas)\n        s_in = x.new_ones([x.shape[0]])\n        return x, s_in, sigmas, num_sigmas, cond, uc\n\n\ndef seeded_randn(shape, seed):\n    randn = np.random.RandomState(seed).randn(*shape)\n    randn = torch.from_numpy(randn).to(device=\"cuda\", dtype=torch.float32)\n    return randn\n\n\nclass SeededNoise:\n    def __init__(self, seed):\n        self.seed = seed\n\n    def __call__(self, x):\n        self.seed = self.seed + 1\n        return seeded_randn(x.shape, self.seed)\n\n\ndef init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):\n    value_dict = {}\n    for key in keys:\n        if key == \"txt\":\n            value_dict[\"prompt\"] = prompt\n            value_dict[\"negative_prompt\"] = negative_prompt or \"\"\n\n        if key == \"original_size_as_tuple\":\n            orig_width = init_dict[\"orig_width\"]\n            orig_height = 
init_dict[\"orig_height\"]\n\n            value_dict[\"orig_width\"] = orig_width\n            value_dict[\"orig_height\"] = orig_height\n\n        if key == \"crop_coords_top_left\":\n            crop_coord_top = 0\n            crop_coord_left = 0\n\n            value_dict[\"crop_coords_top\"] = crop_coord_top\n            value_dict[\"crop_coords_left\"] = crop_coord_left\n\n        if key == \"aesthetic_score\":\n            value_dict[\"aesthetic_score\"] = 6.0\n            value_dict[\"negative_aesthetic_score\"] = 2.5\n\n        if key == \"target_size_as_tuple\":\n            value_dict[\"target_width\"] = init_dict[\"target_width\"]\n            value_dict[\"target_height\"] = init_dict[\"target_height\"]\n\n    return value_dict\n\n\ndef sample(\n    model,\n    sampler,\n    prompt=\"A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.\",\n    H=1024,\n    W=1024,\n    seed=0,\n    filter=None,\n):\n    F = 8\n    C = 4\n    shape = (1, C, H // F, W // F)\n\n    value_dict = init_embedder_options(\n        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),\n        init_dict={\n            \"orig_width\": W,\n            \"orig_height\": H,\n            \"target_width\": W,\n            \"target_height\": H,\n        },\n        prompt=prompt,\n    )\n\n    if seed is None:\n        seed = torch.seed()\n    precision_scope = autocast\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            batch, batch_uc = get_batch(\n                get_unique_embedder_keys_from_conditioner(model.conditioner),\n                value_dict,\n                [1],\n            )\n            c = model.conditioner(batch)\n            uc = None\n            randn = seeded_randn(shape, seed)\n\n            def denoiser(input, sigma, c):\n                return model.denoiser(\n                    model.model,\n                    input,\n                    sigma,\n                    c,\n              
  )\n\n            samples_z = sampler(denoiser, randn, cond=c, uc=uc)\n            samples_x = model.decode_first_stage(samples_z)\n            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n            if filter is not None:\n                samples = filter(samples)\n            samples = (\n                (255 * samples)\n                .to(dtype=torch.uint8)\n                .permute(0, 2, 3, 1)\n                .detach()\n                .cpu()\n                .numpy()\n            )\n    return samples\n\n\ndef v_spacer(height) -> None:\n    for _ in range(height):\n        st.write(\"\\n\")\n\n\nif __name__ == \"__main__\":\n    st.title(\"Turbo\")\n\n    head_cols = st.columns([1, 1, 1])\n    with head_cols[0]:\n        version = st.selectbox(\"Model Version\", list(VERSION2SPECS.keys()), 0)\n        version_dict = VERSION2SPECS[version]\n\n    with head_cols[1]:\n        v_spacer(2)\n        if st.checkbox(\"Load Model\"):\n            mode = \"txt2img\"\n        else:\n            mode = \"skip\"\n\n    if mode != \"skip\":\n        state = init_st(version_dict, load_filter=True)\n        if state[\"msg\"]:\n            st.info(state[\"msg\"])\n        model = state[\"model\"]\n        load_model(model)\n\n    # seed\n    if \"seed\" not in st.session_state:\n        st.session_state.seed = 0\n\n    def increment_counter():\n        st.session_state.seed += 1\n\n    def decrement_counter():\n        if st.session_state.seed > 0:\n            st.session_state.seed -= 1\n\n    with head_cols[2]:\n        n_steps = st.number_input(label=\"number of steps\", min_value=1, max_value=4)\n\n    sampler = SubstepSampler(\n        n_sample_steps=1,\n        num_steps=1000,\n        eta=1.0,\n        discretization_config=dict(\n            target=\"sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\"\n        ),\n    )\n    sampler.n_sample_steps = n_steps\n    default_prompt = (\n        \"A cinematic shot of a baby racoon 
wearing an intricate italian priest robe.\"\n    )\n    prompt = st_keyup(\n        \"Enter a value\", value=default_prompt, debounce=300, key=\"interactive_text\"\n    )\n\n    cols = st.columns([1, 5, 1])\n    if mode != \"skip\":\n        with cols[0]:\n            v_spacer(14)\n            st.button(\"↩\", on_click=decrement_counter)\n        with cols[2]:\n            v_spacer(14)\n            st.button(\"↪\", on_click=increment_counter)\n\n        sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)\n        out = sample(\n            model,\n            sampler,\n            H=512,\n            W=512,\n            seed=st.session_state.seed,\n            prompt=prompt,\n            filter=state.get(\"filter\"),\n        )\n        with cols[1]:\n            st.image(out[0])\n"
  },
  {
    "path": "scripts/demo/video_sampling.py",
    "content": "import os\nimport sys\n\nsys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), \"../../\")))\nfrom pytorch_lightning import seed_everything\nfrom scripts.demo.streamlit_helpers import *\nfrom scripts.demo.sv3d_helpers import *\n\nSAVE_PATH = \"outputs/demo/vid/\"\n\nVERSION2SPECS = {\n    \"svd\": {\n        \"T\": 14,\n        \"H\": 576,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"config\": \"configs/inference/svd.yaml\",\n        \"ckpt\": \"checkpoints/svd.safetensors\",\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 2.5,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 2,\n            \"force_uc_zero_embeddings\": [\"cond_frames\", \"cond_frames_without_noise\"],\n            \"num_steps\": 25,\n        },\n    },\n    \"svd_image_decoder\": {\n        \"T\": 14,\n        \"H\": 576,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"config\": \"configs/inference/svd_image_decoder.yaml\",\n        \"ckpt\": \"checkpoints/svd_image_decoder.safetensors\",\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 2.5,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 2,\n            \"force_uc_zero_embeddings\": [\"cond_frames\", \"cond_frames_without_noise\"],\n            \"num_steps\": 25,\n        },\n    },\n    \"svd_xt\": {\n        \"T\": 25,\n        \"H\": 576,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"config\": \"configs/inference/svd.yaml\",\n        \"ckpt\": \"checkpoints/svd_xt.safetensors\",\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 3.0,\n            \"min_cfg\": 1.5,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 2,\n        
    \"force_uc_zero_embeddings\": [\"cond_frames\", \"cond_frames_without_noise\"],\n            \"num_steps\": 30,\n            \"decoding_t\": 14,\n        },\n    },\n    \"svd_xt_image_decoder\": {\n        \"T\": 25,\n        \"H\": 576,\n        \"W\": 1024,\n        \"C\": 4,\n        \"f\": 8,\n        \"config\": \"configs/inference/svd_image_decoder.yaml\",\n        \"ckpt\": \"checkpoints/svd_xt_image_decoder.safetensors\",\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 3.0,\n            \"min_cfg\": 1.5,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 2,\n            \"force_uc_zero_embeddings\": [\"cond_frames\", \"cond_frames_without_noise\"],\n            \"num_steps\": 30,\n            \"decoding_t\": 14,\n        },\n    },\n    \"sv3d_u\": {\n        \"T\": 21,\n        \"H\": 576,\n        \"W\": 576,\n        \"C\": 4,\n        \"f\": 8,\n        \"config\": \"configs/inference/sv3d_u.yaml\",\n        \"ckpt\": \"checkpoints/sv3d_u.safetensors\",\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 2.5,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 3,\n            \"force_uc_zero_embeddings\": [\"cond_frames\", \"cond_frames_without_noise\"],\n            \"num_steps\": 50,\n            \"decoding_t\": 14,\n        },\n    },\n    \"sv3d_p\": {\n        \"T\": 21,\n        \"H\": 576,\n        \"W\": 576,\n        \"C\": 4,\n        \"f\": 8,\n        \"config\": \"configs/inference/sv3d_p.yaml\",\n        \"ckpt\": \"checkpoints/sv3d_p.safetensors\",\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 2.5,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 3,\n            \"force_uc_zero_embeddings\": [\"cond_frames\", 
\"cond_frames_without_noise\"],\n            \"num_steps\": 50,\n            \"decoding_t\": 14,\n        },\n    },\n}\n\n\nif __name__ == \"__main__\":\n    st.title(\"Stable Video Diffusion / SV3D\")\n    version = st.selectbox(\n        \"Model Version\",\n        [k for k in VERSION2SPECS.keys()],\n        0,\n    )\n    version_dict = VERSION2SPECS[version]\n    if st.checkbox(\"Load Model\"):\n        mode = \"img2vid\"\n    else:\n        mode = \"skip\"\n\n    H = st.sidebar.number_input(\n        \"H\", value=version_dict[\"H\"], min_value=64, max_value=2048\n    )\n    W = st.sidebar.number_input(\n        \"W\", value=version_dict[\"W\"], min_value=64, max_value=2048\n    )\n    T = st.sidebar.number_input(\n        \"T\", value=version_dict[\"T\"], min_value=0, max_value=128\n    )\n    C = version_dict[\"C\"]\n    F = version_dict[\"f\"]\n    options = version_dict[\"options\"]\n\n    if mode != \"skip\":\n        state = init_st(version_dict, load_filter=True)\n        if state[\"msg\"]:\n            st.info(state[\"msg\"])\n        model = state[\"model\"]\n\n        ukeys = set(\n            get_unique_embedder_keys_from_conditioner(state[\"model\"].conditioner)\n        )\n\n        value_dict = init_embedder_options(\n            ukeys,\n            {},\n        )\n\n        if \"fps\" not in ukeys:\n            value_dict[\"fps\"] = 10\n\n        value_dict[\"image_only_indicator\"] = 0\n\n        if mode == \"img2vid\":\n            img = load_img_for_prediction(W, H)\n            if \"sv3d\" in version:\n                cond_aug = 1e-5\n            else:\n                cond_aug = st.number_input(\n                    \"Conditioning augmentation:\", value=0.02, min_value=0.0\n                )\n            value_dict[\"cond_frames_without_noise\"] = img\n            value_dict[\"cond_frames\"] = img + cond_aug * torch.randn_like(img)\n            value_dict[\"cond_aug\"] = cond_aug\n\n        if \"sv3d_p\" in version:\n            elev_deg = 
st.number_input(\"elev_deg\", value=5, min_value=-90, max_value=90)\n            trajectory = st.selectbox(\n                \"Trajectory\",\n                [\"same elevation\", \"dynamic\"],\n                0,\n            )\n            if trajectory == \"same elevation\":\n                value_dict[\"polars_rad\"] = np.array([np.deg2rad(90 - elev_deg)] * T)\n                value_dict[\"azimuths_rad\"] = np.linspace(0, 2 * np.pi, T + 1)[1:]\n            elif trajectory == \"dynamic\":\n                azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)\n                value_dict[\"polars_rad\"] = np.deg2rad(90) - elev_rad\n                value_dict[\"azimuths_rad\"] = azim_rad\n        elif \"sv3d_u\" in version:\n            elev_deg = st.number_input(\"elev_deg\", value=5, min_value=-90, max_value=90)\n            value_dict[\"polars_rad\"] = np.array([np.deg2rad(90 - elev_deg)] * T)\n            value_dict[\"azimuths_rad\"] = np.linspace(0, 2 * np.pi, T + 1)[1:]\n\n        seed = st.sidebar.number_input(\n            \"seed\", value=23, min_value=0, max_value=int(1e9)\n        )\n        seed_everything(seed)\n\n        save_locally, save_path = init_save_locally(\n            os.path.join(SAVE_PATH, version), init_value=True\n        )\n\n        if \"sv3d\" in version:\n            plot_save_path = os.path.join(save_path, \"plot_3D.png\")\n            plot_3D(\n                azim=value_dict[\"azimuths_rad\"],\n                polar=value_dict[\"polars_rad\"],\n                save_path=plot_save_path,\n                dynamic=(\"sv3d_p\" in version),\n            )\n            st.image(\n                plot_save_path,\n                f\"3D camera trajectory\",\n            )\n\n        options[\"num_frames\"] = T\n\n        sampler, num_rows, num_cols = init_sampling(options=options)\n        num_samples = num_rows * num_cols\n\n        decoding_t = st.number_input(\n            \"Decode t frames at a time (set small if you are low 
on VRAM)\",\n            value=options.get(\"decoding_t\", T),\n            min_value=1,\n            max_value=int(1e9),\n        )\n\n        if st.checkbox(\"Overwrite fps in mp4 generator\", False):\n            saving_fps = st.number_input(\n                \"saving video at fps:\", value=value_dict[\"fps\"], min_value=1\n            )\n        else:\n            saving_fps = value_dict[\"fps\"]\n\n        if st.button(\"Sample\"):\n            out = do_sample(\n                model,\n                sampler,\n                value_dict,\n                num_samples,\n                H,\n                W,\n                C,\n                F,\n                T=T,\n                batch2model_input=[\"num_video_frames\", \"image_only_indicator\"],\n                force_uc_zero_embeddings=options.get(\"force_uc_zero_embeddings\", None),\n                force_cond_zero_embeddings=options.get(\n                    \"force_cond_zero_embeddings\", None\n                ),\n                return_latents=False,\n                decoding_t=decoding_t,\n            )\n\n            if isinstance(out, (tuple, list)):\n                samples, samples_z = out\n            else:\n                samples = out\n                samples_z = None\n\n            if save_locally:\n                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)\n"
  },
  {
    "path": "scripts/sampling/configs/sv3d_p.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/sv3d_p.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 1280\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - input_key: cond_frames_without_noise\n          is_trainable: False\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: 
sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: polars_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: azimuths_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n          target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_type: vanilla-xformers\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [ 1, 2, 4, 4 ]\n            num_res_blocks: 2\n            attn_resolutions: [ ]\n            dropout: 0.0\n\n    sampler_config:\n      target: 
sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider\n          params:\n            max_scale: 2.5\n"
  },
  {
    "path": "scripts/sampling/configs/sv3d_u.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/sv3d_u.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 256\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: 
sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n          target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_type: vanilla-xformers\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [ 1, 2, 4, 4 ]\n            num_res_blocks: 2\n            attn_resolutions: [ ]\n            dropout: 0.0\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider\n          params:\n            max_scale: 2.5\n"
  },
  {
    "path": "scripts/sampling/configs/sv4d.yaml",
    "content": "N_TIME: 5\nN_VIEW: 8\nN_FRAMES: 40\n\nmodel:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    en_and_decode_n_samples_a_time: 7\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/sv4d.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime\n      params:\n        adm_in_channels: 1280\n        attention_resolutions: [4, 2, 1]\n        channel_mult: [1, 2, 4, 4]\n        context_dim: 1024\n        motion_context_dim: 4\n        extra_ff_mix_layer: True\n        in_channels: 8\n        legacy: False\n        model_channels: 320\n        num_classes: sequential\n        num_head_channels: 64\n        num_res_blocks: 2\n        out_channels: 4\n        replicate_time_mix_bug: True\n        spatial_transformer_attn_type: softmax-xformers\n        time_block_merge_factor: 0.0\n        time_block_merge_strategy: learned_with_images\n        time_kernel_size: [3, 1, 1]\n        time_mix_legacy: False\n        transformer_depth: 1\n        use_checkpoint: False\n        use_linear_in_transformer: True\n        use_spatial_context: True\n        use_spatial_transformer: True\n        use_motion_attention: True\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n\n        - input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          is_trainable: False\n          params:\n            n_cond_frames: ${N_TIME}\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: 
True\n\n        - input_key: cond_frames\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          is_trainable: False\n          params:\n            is_ae: True\n            n_cond_frames: ${N_FRAMES}\n            n_copies: 1\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                embed_dim: 4\n                lossconfig:\n                  target: torch.nn.Identity\n                monitor: val/rec_loss\n            sigma_cond_config:\n              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n              params:\n                outdim: 256\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n        - input_key: polar_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: azimuth_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: cond_view\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  
attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                lossconfig:\n                  target: torch.nn.Identity\n            is_ae: True\n            n_cond_frames: ${N_VIEW}\n            n_copies: 1\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n        - input_key: cond_motion\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            is_ae: True\n            n_cond_frames: ${N_TIME}\n            n_copies: 1\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                lossconfig:\n                  target: torch.nn.Identity\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n      
    target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_resolutions: []\n            attn_type: vanilla-xformers\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            double_z: True\n            dropout: 0.0\n            in_channels: 3\n            num_res_blocks: 2\n            out_ch: 3\n            resolution: 256\n            z_channels: 4\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 500.0\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\n          params:\n            max_scale: 2.5\n            num_frames: ${N_FRAMES}\n            additional_cond_keys: [ cond_view, cond_motion ]\n"
  },
  {
    "path": "scripts/sampling/configs/sv4d2.yaml",
    "content": "N_TIME: 12\nN_VIEW: 4\nN_FRAMES: 48\n\nmodel:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    en_and_decode_n_samples_a_time: 8\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/sv4d2.safetensors\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime\n      params:\n        adm_in_channels: 1280\n        attention_resolutions: [4, 2, 1]\n        channel_mult: [1, 2, 4, 4]\n        context_dim: 1024\n        motion_context_dim: 4\n        extra_ff_mix_layer: True\n        in_channels: 8\n        legacy: False\n        model_channels: 320\n        num_classes: sequential\n        num_head_channels: 64\n        num_res_blocks: 2\n        out_channels: 4\n        replicate_time_mix_bug: True\n        spatial_transformer_attn_type: softmax-xformers\n        time_block_merge_factor: 0.0\n        time_block_merge_strategy: learned_with_images\n        time_kernel_size: [3, 1, 1]\n        time_mix_legacy: False\n        transformer_depth: 1\n        use_checkpoint: False\n        use_linear_in_transformer: True\n        use_spatial_context: True\n        use_spatial_transformer: True\n        separate_motion_merge_factor: True\n        use_motion_attention: True\n        use_3d_attention: True\n        use_camera_emb: True\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n\n        - input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          is_trainable: False\n          params:\n            n_cond_frames: ${N_TIME}\n            n_copies: 1\n            open_clip_embedding_config:\n              target: 
sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: cond_frames\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          is_trainable: False\n          params:\n            is_ae: True\n            n_cond_frames: ${N_FRAMES}\n            n_copies: 1\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                embed_dim: 4\n                lossconfig:\n                  target: torch.nn.Identity\n                monitor: val/rec_loss\n            sigma_cond_config:\n              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n              params:\n                outdim: 256\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n        - input_key: polar_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: azimuth_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: cond_view\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            is_ae: True\n            n_cond_frames: ${N_VIEW}\n            n_copies: 1\n            encoder_config:\n              target: 
sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                lossconfig:\n                  target: torch.nn.Identity\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n        - input_key: cond_motion\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            is_ae: True\n            n_cond_frames: ${N_TIME}\n            n_copies: 1\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                lossconfig:\n                  target: torch.nn.Identity\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: 
sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n          target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_resolutions: []\n            attn_type: vanilla-xformers\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            double_z: True\n            dropout: 0.0\n            in_channels: 3\n            num_res_blocks: 2\n            out_ch: 3\n            resolution: 256\n            z_channels: 4\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 500.0\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider\n          params:\n            max_scale: 1.5\n            min_scale: 1.5\n            num_frames: ${N_FRAMES}\n            num_views: ${N_VIEW}\n            additional_cond_keys: [ cond_view, cond_motion ]\n"
  },
  {
    "path": "scripts/sampling/configs/sv4d2_8views.yaml",
    "content": "N_TIME: 5\nN_VIEW: 8\nN_FRAMES: 40\n\nmodel:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    en_and_decode_n_samples_a_time: 8\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/sv4d2_8views.safetensors\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime\n      params:\n        adm_in_channels: 1280\n        attention_resolutions: [4, 2, 1]\n        channel_mult: [1, 2, 4, 4]\n        context_dim: 1024\n        motion_context_dim: 4\n        extra_ff_mix_layer: True\n        in_channels: 8\n        legacy: False\n        model_channels: 320\n        num_classes: sequential\n        num_head_channels: 64\n        num_res_blocks: 2\n        out_channels: 4\n        replicate_time_mix_bug: True\n        spatial_transformer_attn_type: softmax-xformers\n        time_block_merge_factor: 0.0\n        time_block_merge_strategy: learned_with_images\n        time_kernel_size: [3, 1, 1]\n        time_mix_legacy: False\n        transformer_depth: 1\n        use_checkpoint: False\n        use_linear_in_transformer: True\n        use_spatial_context: True\n        use_spatial_transformer: True\n        separate_motion_merge_factor: True\n        use_motion_attention: True\n        use_3d_attention: False\n        use_camera_emb: True\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n\n        - input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          is_trainable: False\n          params:\n            n_cond_frames: ${N_TIME}\n            n_copies: 1\n            open_clip_embedding_config:\n              
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: cond_frames\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          is_trainable: False\n          params:\n            is_ae: True\n            n_cond_frames: ${N_FRAMES}\n            n_copies: 1\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                embed_dim: 4\n                lossconfig:\n                  target: torch.nn.Identity\n                monitor: val/rec_loss\n            sigma_cond_config:\n              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n              params:\n                outdim: 256\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n        - input_key: polar_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: azimuth_rad\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 512\n\n        - input_key: cond_view\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            is_ae: True\n            n_cond_frames: ${N_VIEW}\n            n_copies: 1\n            encoder_config:\n              
target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                lossconfig:\n                  target: torch.nn.Identity\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n        - input_key: cond_motion\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            is_ae: True\n            n_cond_frames: ${N_TIME}\n            n_copies: 1\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_resolutions: []\n                  attn_type: vanilla-xformers\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  double_z: True\n                  dropout: 0.0\n                  in_channels: 3\n                  num_res_blocks: 2\n                  out_ch: 3\n                  resolution: 256\n                  z_channels: 4\n                lossconfig:\n                  target: torch.nn.Identity\n            sigma_sampler_config:\n              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config:\n          target: torch.nn.Identity\n        decoder_config:\n          target: sgm.modules.diffusionmodules.model.Decoder\n          params:\n            attn_resolutions: []\n            attn_type: vanilla-xformers\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            double_z: True\n            dropout: 0.0\n            in_channels: 3\n            num_res_blocks: 2\n            out_ch: 3\n            resolution: 256\n            z_channels: 4\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        num_steps: 50\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 500.0\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider\n          params:\n            max_scale: 2.0\n            min_scale: 1.5\n            num_frames: ${N_FRAMES}\n            num_views: ${N_VIEW}\n            additional_cond_keys: [ cond_view, cond_motion ]\n"
  },
  {
    "path": "scripts/sampling/configs/svd.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/svd.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n          
  outdim: 256\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config: \n          target: sgm.modules.diffusionmodules.model.Encoder\n          params:\n            attn_type: vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n        decoder_config:\n          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder\n          params:\n            attn_type: 
vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n            video_kernel_size: [3, 1, 1]\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\n          params:\n            max_scale: 2.5\n            min_scale: 1.0"
  },
  {
    "path": "scripts/sampling/configs/svd_image_decoder.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/svd_image_decoder.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          
params:\n            outdim: 256\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: True\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        
guider_config:\n          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\n          params:\n            max_scale: 2.5\n            min_scale: 1.0"
  },
  {
    "path": "scripts/sampling/configs/svd_xt.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/svd_xt.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n       
     outdim: 256\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config: \n          target: sgm.modules.diffusionmodules.model.Encoder\n          params:\n            attn_type: vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n        decoder_config:\n          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder\n          params:\n            attn_type: 
vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n            video_kernel_size: [3, 1, 1]\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\n          params:\n            max_scale: 3.0\n            min_scale: 1.5"
  },
  {
    "path": "scripts/sampling/configs/svd_xt_1_1.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/svd_xt_1_1.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n   
         outdim: 256\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencodingEngine\n      params:\n        loss_config:\n          target: torch.nn.Identity\n        regularizer_config:\n          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer\n        encoder_config: \n          target: sgm.modules.diffusionmodules.model.Encoder\n          params:\n            attn_type: vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n        decoder_config:\n          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder\n          params:\n            
attn_type: vanilla\n            double_z: True\n            z_channels: 4\n            resolution: 256\n            in_channels: 3\n            out_ch: 3\n            ch: 128\n            ch_mult: [1, 2, 4, 4]\n            num_res_blocks: 2\n            attn_resolutions: []\n            dropout: 0.0\n            video_kernel_size: [3, 1, 1]\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        guider_config:\n          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\n          params:\n            max_scale: 3.0\n            min_scale: 1.5\n"
  },
  {
    "path": "scripts/sampling/configs/svd_xt_image_decoder.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.18215\n    disable_first_stage_autocast: True\n    ckpt_path: checkpoints/svd_xt_image_decoder.safetensors\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.Denoiser\n      params:\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise\n\n    network_config:\n      target: sgm.modules.diffusionmodules.video_model.VideoUNet\n      params:\n        adm_in_channels: 768\n        num_classes: sequential\n        use_checkpoint: True\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2, 1]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4, 4]\n        num_head_channels: 64\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        spatial_transformer_attn_type: softmax-xformers\n        extra_ff_mix_layer: True\n        use_spatial_context: True\n        merge_strategy: learned_with_images\n        video_kernel_size: [3, 1, 1]\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n        - is_trainable: False\n          input_key: cond_frames_without_noise\n          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder\n          params:\n            n_cond_frames: 1\n            n_copies: 1\n            open_clip_embedding_config:\n              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder\n              params:\n                freeze: True\n\n        - input_key: fps_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n        - input_key: motion_bucket_id\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          
params:\n            outdim: 256\n\n        - input_key: cond_frames\n          is_trainable: False\n          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder\n          params:\n            disable_encoder_autocast: True\n            n_cond_frames: 1\n            n_copies: 1\n            is_ae: True\n            encoder_config:\n              target: sgm.models.autoencoder.AutoencoderKLModeOnly\n              params:\n                embed_dim: 4\n                monitor: val/rec_loss\n                ddconfig:\n                  attn_type: vanilla-xformers\n                  double_z: True\n                  z_channels: 4\n                  resolution: 256\n                  in_channels: 3\n                  out_ch: 3\n                  ch: 128\n                  ch_mult: [1, 2, 4, 4]\n                  num_res_blocks: 2\n                  attn_resolutions: []\n                  dropout: 0.0\n                lossconfig:\n                  target: torch.nn.Identity\n\n        - input_key: cond_aug\n          is_trainable: False\n          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n          params:\n            outdim: 256\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: True\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    sampler_config:\n      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler\n      params:\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization\n          params:\n            sigma_max: 700.0\n\n        
guider_config:\n          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider\n          params:\n            max_scale: 3.0\n            min_scale: 1.5"
  },
  {
    "path": "scripts/sampling/simple_video_sample.py",
    "content": "import math\nimport os\nimport sys\nfrom glob import glob\nfrom pathlib import Path\nfrom typing import List, Optional, Union\n\nsys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), \"../../\")))\nimport cv2\nimport imageio\nimport numpy as np\nimport torch\nfrom einops import rearrange, repeat\nfrom fire import Fire\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfrom rembg import remove\nfrom scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\nfrom sgm.inference.helpers import embed_watermark\nfrom sgm.util import default, instantiate_from_config\nfrom torchvision.transforms import ToTensor\n\n\ndef sample(\n    input_path: str = \"assets/test_image.png\",  # Can be either an image file or a folder of image files\n    num_frames: Optional[int] = None,  # 21 for SV3D\n    num_steps: Optional[int] = None,\n    version: str = \"svd\",\n    fps_id: int = 6,\n    motion_bucket_id: int = 127,\n    cond_aug: float = 0.02,\n    seed: int = 23,\n    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n    device: str = \"cuda\",\n    output_folder: Optional[str] = None,\n    elevations_deg: Optional[Union[float, List[float]]] = 10.0,  # For SV3D\n    azimuths_deg: Optional[List[float]] = None,  # For SV3D\n    image_frame_ratio: Optional[float] = None,\n    verbose: Optional[bool] = False,\n):\n    \"\"\"\n    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each\n    image file in folder `input_path`. 
If you run out of VRAM, try decreasing `decoding_t`.\n    \"\"\"\n\n    if version == \"svd\":\n        num_frames = default(num_frames, 14)\n        num_steps = default(num_steps, 25)\n        output_folder = default(output_folder, \"outputs/simple_video_sample/svd/\")\n        model_config = \"scripts/sampling/configs/svd.yaml\"\n    elif version == \"svd_xt\":\n        num_frames = default(num_frames, 25)\n        num_steps = default(num_steps, 30)\n        output_folder = default(output_folder, \"outputs/simple_video_sample/svd_xt/\")\n        model_config = \"scripts/sampling/configs/svd_xt.yaml\"\n    elif version == \"svd_image_decoder\":\n        num_frames = default(num_frames, 14)\n        num_steps = default(num_steps, 25)\n        output_folder = default(\n            output_folder, \"outputs/simple_video_sample/svd_image_decoder/\"\n        )\n        model_config = \"scripts/sampling/configs/svd_image_decoder.yaml\"\n    elif version == \"svd_xt_image_decoder\":\n        num_frames = default(num_frames, 25)\n        num_steps = default(num_steps, 30)\n        output_folder = default(\n            output_folder, \"outputs/simple_video_sample/svd_xt_image_decoder/\"\n        )\n        model_config = \"scripts/sampling/configs/svd_xt_image_decoder.yaml\"\n    elif version == \"sv3d_u\":\n        num_frames = 21\n        num_steps = default(num_steps, 50)\n        output_folder = default(output_folder, \"outputs/simple_video_sample/sv3d_u/\")\n        model_config = \"scripts/sampling/configs/sv3d_u.yaml\"\n        cond_aug = 1e-5\n    elif version == \"sv3d_p\":\n        num_frames = 21\n        num_steps = default(num_steps, 50)\n        output_folder = default(output_folder, \"outputs/simple_video_sample/sv3d_p/\")\n        model_config = \"scripts/sampling/configs/sv3d_p.yaml\"\n        cond_aug = 1e-5\n        if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):\n            elevations_deg = [elevations_deg] * num_frames\n        
assert (\n            len(elevations_deg) == num_frames\n        ), f\"Please provide 1 value, or a list of {num_frames} values for elevations_deg! Given {len(elevations_deg)}\"\n        polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]\n        if azimuths_deg is None:\n            azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360\n        assert (\n            len(azimuths_deg) == num_frames\n        ), f\"Please provide a list of {num_frames} values for azimuths_deg! Given {len(azimuths_deg)}\"\n        azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]\n    else:\n        raise ValueError(f\"Version {version} does not exist.\")\n\n    model, filter = load_model(\n        model_config,\n        device,\n        num_frames,\n        num_steps,\n        verbose,\n    )\n    torch.manual_seed(seed)\n\n    path = Path(input_path)\n    all_img_paths = []\n    if path.is_file():\n        if path.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]:\n            all_img_paths = [input_path]\n        else:\n            raise ValueError(\"Path is not a valid image file.\")\n    elif path.is_dir():\n        all_img_paths = sorted(\n            [\n                f\n                for f in path.iterdir()\n                if f.is_file() and f.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]\n            ]\n        )\n        if len(all_img_paths) == 0:\n            raise ValueError(\"Folder does not contain any images.\")\n    else:\n        raise ValueError(f\"Input path {input_path} is not a file or directory.\")\n\n    for input_img_path in all_img_paths:\n        if \"sv3d\" in version:\n            image = Image.open(input_img_path)\n            if image.mode == \"RGBA\":\n                pass\n            else:\n                # remove bg\n                image.thumbnail([768, 768], Image.Resampling.LANCZOS)\n                image = remove(image.convert(\"RGBA\"), alpha_matting=True)\n\n            # resize object in frame\n
            image_arr = np.array(image)\n            in_h, in_w = image_arr.shape[:2]\n            ret, mask = cv2.threshold(\n                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY\n            )\n            x, y, w, h = cv2.boundingRect(mask)\n            max_size = max(w, h)\n            side_len = (\n                int(max_size / image_frame_ratio)\n                if image_frame_ratio is not None\n                else max(in_h, in_w)\n            )\n            padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)\n            center = side_len // 2\n            padded_image[\n                center - h // 2 : center - h // 2 + h,\n                center - w // 2 : center - w // 2 + w,\n            ] = image_arr[y : y + h, x : x + w]\n            # resize frame to 576x576\n            rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)\n            # white bg\n            rgba_arr = np.array(rgba) / 255.0\n            rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])\n            input_image = Image.fromarray((rgb * 255).astype(np.uint8))\n\n        else:\n            with Image.open(input_img_path) as image:\n                # Force 3 channels (drops alpha; expands grayscale/palette)\n                input_image = image.convert(\"RGB\")\n                w, h = input_image.size\n\n                if h % 64 != 0 or w % 64 != 0:\n                    width, height = map(lambda x: x - x % 64, (w, h))\n                    input_image = input_image.resize((width, height))\n                    print(\n                        f\"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!\"\n                    )\n                input_image = np.array(input_image)\n\n
        image = ToTensor()(input_image)\n        image = image * 2.0 - 1.0\n\n        image = image.unsqueeze(0).to(device)\n        H, W = image.shape[2:]\n        assert image.shape[1] == 3\n        F = 8\n        C = 4\n        shape = (num_frames, C, H // F, W // F)\n        if (H, W) != (576, 1024) and \"sv3d\" not in version:\n            print(\n                \"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as the model was only trained on 576x1024. Consider increasing `cond_aug`.\"\n            )\n        if (H, W) != (576, 576) and \"sv3d\" in version:\n            print(\n                \"WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as the model was only trained on 576x576.\"\n            )\n        if motion_bucket_id > 255:\n            print(\n                \"WARNING: High motion bucket! This may lead to suboptimal performance.\"\n            )\n\n        if fps_id < 5:\n            print(\"WARNING: Small fps value! This may lead to suboptimal performance.\")\n\n        if fps_id > 30:\n            print(\"WARNING: Large fps value! 
This may lead to suboptimal performance.\")\n\n        value_dict = {}\n        value_dict[\"cond_frames_without_noise\"] = image\n        value_dict[\"motion_bucket_id\"] = motion_bucket_id\n        value_dict[\"fps_id\"] = fps_id\n        value_dict[\"cond_aug\"] = cond_aug\n        value_dict[\"cond_frames\"] = image + cond_aug * torch.randn_like(image)\n        if \"sv3d_p\" in version:\n            value_dict[\"polars_rad\"] = polars_rad\n            value_dict[\"azimuths_rad\"] = azimuths_rad\n\n        with torch.no_grad():\n            with torch.autocast(device):\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    [1, num_frames],\n                    T=num_frames,\n                    device=device,\n                )\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=[\n                        \"cond_frames\",\n                        \"cond_frames_without_noise\",\n                    ],\n                )\n\n                for k in [\"crossattn\", \"concat\"]:\n                    uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n                    uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n                    c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n                    c[k] = rearrange(c[k], \"b t ... 
-> (b t) ...\", t=num_frames)\n\n                randn = torch.randn(shape, device=device)\n\n                additional_model_inputs = {}\n                additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n                    2, num_frames\n                ).to(device)\n                additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n                def denoiser(input, sigma, c):\n                    return model.denoiser(\n                        model.model, input, sigma, c, **additional_model_inputs\n                    )\n\n                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n                model.en_and_decode_n_samples_a_time = decoding_t\n                samples_x = model.decode_first_stage(samples_z)\n                if \"sv3d\" in version:\n                    samples_x[-1:] = value_dict[\"cond_frames_without_noise\"]\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n\n                os.makedirs(output_folder, exist_ok=True)\n                base_count = len(glob(os.path.join(output_folder, \"*.mp4\")))\n\n                imageio.imwrite(\n                    os.path.join(output_folder, f\"{base_count:06d}.jpg\"), input_image\n                )\n\n                samples = embed_watermark(samples)\n                samples = filter(samples)\n                vid = (\n                    (rearrange(samples, \"t c h w -> t h w c\") * 255)\n                    .cpu()\n                    .numpy()\n                    .astype(np.uint8)\n                )\n                video_path = os.path.join(output_folder, f\"{base_count:06d}.mp4\")\n                imageio.mimwrite(video_path, vid)\n\n\ndef get_unique_embedder_keys_from_conditioner(conditioner):\n    return list(set([x.input_key for x in conditioner.embedders]))\n\n\ndef get_batch(keys, value_dict, N, T, device):\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"fps_id\":\n    
        batch[key] = (\n                torch.tensor([value_dict[\"fps_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"motion_bucket_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"motion_bucket_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"cond_aug\":\n            batch[key] = repeat(\n                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n                \"1 -> b\",\n                b=math.prod(N),\n            )\n        elif key == \"cond_frames\" or key == \"cond_frames_without_noise\":\n            batch[key] = repeat(value_dict[key], \"1 ... -> b ...\", b=N[0])\n        elif key == \"polars_rad\" or key == \"azimuths_rad\":\n            batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])\n        else:\n            batch[key] = value_dict[key]\n\n    if T is not None:\n        batch[\"num_video_frames\"] = T\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n    return batch, batch_uc\n\n\ndef load_model(\n    config: str,\n    device: str,\n    num_frames: int,\n    num_steps: int,\n    verbose: bool = False,\n):\n    config = OmegaConf.load(config)\n    if device == \"cuda\":\n        config.model.params.conditioner_config.params.emb_models[\n            0\n        ].params.open_clip_embedding_config.params.init_device = device\n\n    config.model.params.sampler_config.params.verbose = verbose\n    config.model.params.sampler_config.params.num_steps = num_steps\n    config.model.params.sampler_config.params.guider_config.params.num_frames = (\n        num_frames\n    )\n    if device == \"cuda\":\n        with torch.device(device):\n            model = instantiate_from_config(config.model).to(device).eval()\n    else:\n        model = 
instantiate_from_config(config.model).to(device).eval()\n\n    filter = DeepFloydDataFiltering(verbose=False, device=device)\n    return model, filter\n\n\nif __name__ == \"__main__\":\n    Fire(sample)\n"
  },
  {
    "path": "scripts/sampling/simple_video_sample_4d.py",
    "content": "import os\nimport sys\nfrom glob import glob\nfrom typing import List, Optional, Union\n\nfrom tqdm import tqdm\n\nsys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), \"../../\")))\nimport numpy as np\nimport torch\nfrom fire import Fire\n\nfrom sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder\nfrom scripts.demo.sv4d_helpers import (\n    decode_latents,\n    load_model,\n    initial_model_load,\n    read_video,\n    run_img2vid,\n    prepare_sampling,\n    prepare_inputs,\n    do_sample_per_step,\n    sample_sv3d,\n    save_video,\n    preprocess_video,\n)\n\n\ndef sample(\n    input_path: str = \"assets/sv4d_videos/test_video1.mp4\",  # Can be either a video file or a folder of frame images\n    output_folder: Optional[str] = \"outputs/sv4d\",\n    num_steps: Optional[int] = 20,\n    sv3d_version: str = \"sv3d_u\",  # sv3d_u or sv3d_p\n    img_size: int = 576,  # image resolution\n    fps_id: int = 6,\n    motion_bucket_id: int = 127,\n    cond_aug: float = 1e-5,\n    seed: int = 23,\n    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.\n    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n    device: str = \"cuda\",\n    elevations_deg: Optional[Union[float, List[float]]] = 10.0,\n    azimuths_deg: Optional[List[float]] = None,\n    image_frame_ratio: Optional[float] = 0.917,\n    verbose: Optional[bool] = False,\n    remove_bg: bool = False,\n):\n    \"\"\"\n    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each\n    image file in folder `input_path`. 
If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.\n    \"\"\"\n    # Set model config\n    T = 5  # number of frames per sample\n    V = 8  # number of views per sample\n    F = 8  # vae factor to downsize image->latent\n    C = 4\n    H, W = img_size, img_size\n    n_frames = 21  # number of input and output video frames\n    n_views = V + 1  # number of output video views (1 input view + 8 novel views)\n    n_views_sv3d = 21\n    subsampled_views = np.array(\n        [0, 2, 5, 7, 9, 12, 14, 16, 19]\n    )  # subsample (V+1=)9 (uniform) views from 21 SV3D views\n\n    model_config = \"scripts/sampling/configs/sv4d.yaml\"\n    version_dict = {\n        \"T\": T * V,\n        \"H\": H,\n        \"W\": W,\n        \"C\": C,\n        \"f\": F,\n        \"options\": {\n            \"discretization\": 1,\n            \"cfg\": 2.0,\n            \"num_views\": V,\n            \"sigma_min\": 0.002,\n            \"sigma_max\": 700.0,\n            \"rho\": 7.0,\n            \"guider\": 5,\n            \"num_steps\": num_steps,\n            \"force_uc_zero_embeddings\": [\n                \"cond_frames\",\n                \"cond_frames_without_noise\",\n                \"cond_view\",\n                \"cond_motion\",\n            ],\n            \"additional_guider_kwargs\": {\n                \"additional_cond_keys\": [\"cond_view\", \"cond_motion\"]\n            },\n        },\n    }\n\n    torch.manual_seed(seed)\n    os.makedirs(output_folder, exist_ok=True)\n\n    # Read input video frames i.e. 
images at view 0\n    print(f\"Reading {input_path}\")\n    base_count = len(glob(os.path.join(output_folder, \"*.mp4\"))) // 11\n    processed_input_path = preprocess_video(\n        input_path,\n        remove_bg=remove_bg,\n        n_frames=n_frames,\n        W=W,\n        H=H,\n        output_folder=output_folder,\n        image_frame_ratio=image_frame_ratio,\n        base_count=base_count,\n    )\n    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)\n\n    # Get camera viewpoints\n    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):\n        elevations_deg = [elevations_deg] * n_views_sv3d\n    assert (\n        len(elevations_deg) == n_views_sv3d\n    ), f\"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}\"\n    if azimuths_deg is None:\n        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360\n    assert (\n        len(azimuths_deg) == n_views_sv3d\n    ), f\"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}\"\n    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])\n    azimuths_rad = np.array(\n        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]\n    )\n\n    # Sample multi-view images of the first frame using SV3D i.e. 
images at time 0\n    images_t0 = sample_sv3d(\n        images_v0[0],\n        n_views_sv3d,\n        num_steps,\n        sv3d_version,\n        fps_id,\n        motion_bucket_id,\n        cond_aug,\n        decoding_t,\n        device,\n        polars_rad,\n        azimuths_rad,\n        verbose,\n    )\n    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame\n\n    # Initialize image matrix\n    img_matrix = [[None] * n_views for _ in range(n_frames)]\n    for i, v in enumerate(subsampled_views):\n        img_matrix[0][i] = images_t0[v].unsqueeze(0)\n    for t in range(n_frames):\n        img_matrix[t][0] = images_v0[t]\n\n    save_video(\n        os.path.join(output_folder, f\"{base_count:06d}_t000.mp4\"),\n        img_matrix[0],\n    )\n    # save_video(\n    #     os.path.join(output_folder, f\"{base_count:06d}_v000.mp4\"),\n    #     [img_matrix[t][0] for t in range(n_frames)],\n    # )\n\n    # Load SV4D model\n    model, filter = load_model(\n        model_config,\n        device,\n        version_dict[\"T\"],\n        num_steps,\n        verbose,\n    )\n    model = initial_model_load(model)\n    for emb in model.conditioner.embedders:\n        if isinstance(emb, VideoPredictionEmbedderWithEncoder):\n            emb.en_and_decode_n_samples_a_time = encoding_t\n    model.en_and_decode_n_samples_a_time = decoding_t\n\n    # Interleaved sampling for anchor frames\n    t0, v0 = 0, 0\n    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]\n    view_indices = np.arange(V) + 1\n    print(f\"Sampling anchor frames {frame_indices}\")\n    image = img_matrix[t0][v0]\n    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)\n    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)\n    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n    azims = (azims - azimuths_rad[v0]) % (torch.pi * 
2)\n    samples = run_img2vid(\n        version_dict,\n        model,\n        image,\n        seed,\n        polars,\n        azims,\n        cond_motion,\n        cond_view,\n        decoding_t,\n    )\n    samples = samples.view(T, V, 3, H, W)\n    for i, t in enumerate(frame_indices):\n        for j, v in enumerate(view_indices):\n            if img_matrix[t][v] is None:\n                img_matrix[t][v] = samples[i, j][None] * 2 - 1\n\n    # Dense sampling for the rest\n    print(\"Sampling dense frames:\")\n    for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)):  # [0, 4, 8, 12, 16]\n        frame_indices = t0 + np.arange(T)\n        print(f\"Sampling dense frames {frame_indices}\")\n        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F, device=device)\n\n        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)\n\n        # alternate between forward and backward conditioning\n        (\n            forward_inputs,\n            forward_frame_indices,\n            backward_inputs,\n            backward_frame_indices,\n        ) = prepare_inputs(\n            frame_indices,\n            img_matrix,\n            v0,\n            view_indices,\n            model,\n            version_dict,\n            seed,\n            polars,\n            azims,\n        )\n\n        for step in tqdm(range(num_steps)):\n            if step % 2 == 1:\n                c, uc, additional_model_inputs, sampler = forward_inputs\n                frame_indices = forward_frame_indices\n            else:\n                c, uc, additional_model_inputs, sampler = backward_inputs\n                frame_indices = backward_frame_indices\n            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)\n\n            samples = do_sample_per_step(\n                model,\n                sampler,\n                noisy_latents,\n                c,\n                uc,\n                
step,\n                additional_model_inputs,\n            )\n            samples = samples.view(T, V, C, H // F, W // F)\n            for i, t in enumerate(frame_indices):\n                for j, v in enumerate(view_indices):\n                    latent_matrix[t, v] = samples[i, j]\n\n        img_matrix = decode_latents(\n            model, latent_matrix, img_matrix, frame_indices, view_indices, T\n        )\n\n    # Save output videos\n    for v in view_indices:\n        vid_file = os.path.join(output_folder, f\"{base_count:06d}_v{v:03d}.mp4\")\n        print(f\"Saving {vid_file}\")\n        save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)])\n\n    # Save diagonal video\n    diag_frames = [\n        img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames)\n    ]\n    vid_file = os.path.join(output_folder, f\"{base_count:06d}_diag.mp4\")\n    print(f\"Saving {vid_file}\")\n    save_video(vid_file, diag_frames)\n\n\nif __name__ == \"__main__\":\n    Fire(sample)\n"
  },
  {
    "path": "scripts/sampling/simple_video_sample_4d2.py",
    "content": "import os\nimport sys\nfrom glob import glob\nfrom typing import List, Optional\n\nfrom tqdm import tqdm\n\nsys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), \"../../\")))\nimport numpy as np\nimport torch\nfrom fire import Fire\nfrom scripts.demo.sv4d_helpers import (\n    load_model,\n    preprocess_video,\n    read_video,\n    run_img2vid,\n    save_video,\n)\nfrom sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder\n\nsv4d2_configs = {\n    \"sv4d2\": {\n        \"T\": 12,  # number of frames per sample\n        \"V\": 4,  # number of views per sample\n        \"model_config\": \"scripts/sampling/configs/sv4d2.yaml\",\n        \"version_dict\": {\n            \"T\": 12 * 4,\n            \"options\": {\n                \"discretization\": 1,\n                \"cfg\": 2.0,\n                \"min_cfg\": 2.0,\n                \"num_views\": 4,\n                \"sigma_min\": 0.002,\n                \"sigma_max\": 700.0,\n                \"rho\": 7.0,\n                \"guider\": 2,\n                \"force_uc_zero_embeddings\": [\n                    \"cond_frames\",\n                    \"cond_frames_without_noise\",\n                    \"cond_view\",\n                    \"cond_motion\",\n                ],\n                \"additional_guider_kwargs\": {\n                    \"additional_cond_keys\": [\"cond_view\", \"cond_motion\"]\n                },\n            },\n        },\n    },\n    \"sv4d2_8views\": {\n        \"T\": 5,  # number of frames per sample\n        \"V\": 8,  # number of views per sample\n        \"model_config\": \"scripts/sampling/configs/sv4d2_8views.yaml\",\n        \"version_dict\": {\n            \"T\": 5 * 8,\n            \"options\": {\n                \"discretization\": 1,\n                \"cfg\": 2.5,\n                \"min_cfg\": 1.5,\n                \"num_views\": 8,\n                \"sigma_min\": 0.002,\n                \"sigma_max\": 700.0,\n                
\"rho\": 7.0,\n                \"guider\": 5,\n                \"force_uc_zero_embeddings\": [\n                    \"cond_frames\",\n                    \"cond_frames_without_noise\",\n                    \"cond_view\",\n                    \"cond_motion\",\n                ],\n                \"additional_guider_kwargs\": {\n                    \"additional_cond_keys\": [\"cond_view\", \"cond_motion\"]\n                },\n            },\n        },\n    },\n}\n\n\ndef sample(\n    input_path: str = \"assets/sv4d_videos/camel.gif\",  # Can either be image file or folder with image files\n    model_path: Optional[str] = \"checkpoints/sv4d2.safetensors\",\n    output_folder: Optional[str] = \"outputs\",\n    num_steps: Optional[int] = 50,\n    img_size: int = 576,  # image resolution\n    n_frames: int = 21,  # number of input and output video frames\n    seed: int = 23,\n    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.\n    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n    device: str = \"cuda\",\n    elevations_deg: Optional[List[float]] = 0.0,\n    azimuths_deg: Optional[List[float]] = None,\n    image_frame_ratio: Optional[float] = 0.9,\n    verbose: Optional[bool] = False,\n    remove_bg: bool = False,\n):\n    \"\"\"\n    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each\n    image file in folder `input_path`. 
If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.\n    \"\"\"\n    # Set model config\n    assert os.path.basename(model_path) in [\n        \"sv4d2.safetensors\",\n        \"sv4d2_8views.safetensors\",\n    ]\n    sv4d2_model = os.path.splitext(os.path.basename(model_path))[0]\n    config = sv4d2_configs[sv4d2_model]\n    print(sv4d2_model, config)\n    T = config[\"T\"]\n    V = config[\"V\"]\n    model_config = config[\"model_config\"]\n    version_dict = config[\"version_dict\"]\n    F = 8  # vae factor to downsize image->latent\n    C = 4\n    H, W = img_size, img_size\n    n_views = V + 1  # number of output video views (1 input view + V novel views)\n    subsampled_views = np.arange(n_views)\n    version_dict[\"H\"] = H\n    version_dict[\"W\"] = W\n    version_dict[\"C\"] = C\n    version_dict[\"f\"] = F\n    version_dict[\"options\"][\"num_steps\"] = num_steps\n\n    torch.manual_seed(seed)\n    output_folder = os.path.join(output_folder, sv4d2_model)\n    os.makedirs(output_folder, exist_ok=True)\n\n    # Read input video frames i.e. images at view 0\n    print(f\"Reading {input_path}\")\n    base_count = len(glob(os.path.join(output_folder, \"*.mp4\"))) // n_views\n    processed_input_path = preprocess_video(\n        input_path,\n        remove_bg=remove_bg,\n        n_frames=n_frames,\n        W=W,\n        H=H,\n        output_folder=output_folder,\n        image_frame_ratio=image_frame_ratio,\n        base_count=base_count,\n    )\n    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)\n    images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)\n\n    # Get camera viewpoints\n    if isinstance(elevations_deg, (float, int)):\n        elevations_deg = [elevations_deg] * n_views\n    assert (\n        len(elevations_deg) == n_views\n    ), f\"Please provide 1 value, or a list of {n_views} values for elevations_deg! 
Given {len(elevations_deg)}\"\n    if azimuths_deg is None:\n        # azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360\n        azimuths_deg = (\n            np.array([0, 60, 120, 180, 240])\n            if sv4d2_model == \"sv4d2\"\n            else np.array([0, 30, 75, 120, 165, 210, 255, 300, 330])\n        )\n    assert (\n        len(azimuths_deg) == n_views\n    ), f\"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}\"\n    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])\n    azimuths_rad = np.array(\n        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]\n    )\n\n    # Initialize image matrix\n    img_matrix = [[None] * n_views for _ in range(n_frames)]\n    for i, v in enumerate(subsampled_views):\n        img_matrix[0][i] = images_t0[v].unsqueeze(0)\n    for t in range(n_frames):\n        img_matrix[t][0] = images_v0[t]\n\n    # Load SV4D++ model\n    model, _ = load_model(\n        model_config,\n        device,\n        version_dict[\"T\"],\n        num_steps,\n        verbose,\n        model_path,\n    )\n    model.en_and_decode_n_samples_a_time = decoding_t\n    for emb in model.conditioner.embedders:\n        if isinstance(emb, VideoPredictionEmbedderWithEncoder):\n            emb.en_and_decode_n_samples_a_time = encoding_t\n\n    # Sampling novel-view videos\n    v0 = 0\n    view_indices = np.arange(V) + 1\n    t0_list = (\n        range(0, n_frames, T - 1)\n        if sv4d2_model == \"sv4d2\"\n        else range(0, n_frames - T + 1, T - 1)\n    )\n    for t0 in tqdm(t0_list):\n        if t0 + T > n_frames:\n            t0 = n_frames - T\n        frame_indices = t0 + np.arange(T)\n        print(f\"Sampling frames {frame_indices}\")\n        image = img_matrix[t0][v0]\n        cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)\n        cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)\n        polars = 
polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()\n        polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)\n        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)\n        cond_mv = t0 != 0\n        samples = run_img2vid(\n            version_dict,\n            model,\n            image,\n            seed,\n            polars,\n            azims,\n            cond_motion,\n            cond_view,\n            decoding_t,\n            cond_mv=cond_mv,\n        )\n        samples = samples.view(T, V, 3, H, W)\n\n        for i, t in enumerate(frame_indices):\n            for j, v in enumerate(view_indices):\n                img_matrix[t][v] = samples[i, j][None] * 2 - 1\n\n    # Save output videos\n    for v in view_indices:\n        vid_file = os.path.join(output_folder, f\"{base_count:06d}_v{v:03d}.mp4\")\n        print(f\"Saving {vid_file}\")\n        save_video(\n            vid_file,\n            [img_matrix[t][v] for t in range(n_frames) if img_matrix[t][v] is not None],\n        )\n\n\nif __name__ == \"__main__\":\n    Fire(sample)\n"
  },
  {
    "path": "scripts/tests/attention.py",
    "content": "import einops\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.benchmark as benchmark\nfrom torch.backends.cuda import SDPBackend\n\nfrom sgm.modules.attention import BasicTransformerBlock, SpatialTransformer\n\n\ndef benchmark_attn():\n    # Let's define a helpful benchmarking function:\n    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):\n        t0 = benchmark.Timer(\n            stmt=\"f(*args, **kwargs)\", globals={\"args\": args, \"kwargs\": kwargs, \"f\": f}\n        )\n        return t0.blocked_autorange().mean * 1e6\n\n    # Let's define the hyper-parameters of our input\n    batch_size = 32\n    max_sequence_len = 1024\n    num_heads = 32\n    embed_dimension = 32\n\n    dtype = torch.float16\n\n    query = torch.rand(\n        batch_size,\n        num_heads,\n        max_sequence_len,\n        embed_dimension,\n        device=device,\n        dtype=dtype,\n    )\n    key = torch.rand(\n        batch_size,\n        num_heads,\n        max_sequence_len,\n        embed_dimension,\n        device=device,\n        dtype=dtype,\n    )\n    value = torch.rand(\n        batch_size,\n        num_heads,\n        max_sequence_len,\n        embed_dimension,\n        device=device,\n        dtype=dtype,\n    )\n\n    print(\"q/k/v shape:\", query.shape, key.shape, value.shape)\n\n    # Let's explore the speed of each of the 3 implementations\n    from torch.backends.cuda import sdp_kernel\n\n    # Helpful arguments mapper\n    backend_map = {\n        SDPBackend.MATH: {\n            \"enable_math\": True,\n            \"enable_flash\": False,\n            \"enable_mem_efficient\": False,\n        },\n        SDPBackend.FLASH_ATTENTION: {\n            \"enable_math\": False,\n            \"enable_flash\": True,\n            \"enable_mem_efficient\": 
False,\n        },\n        SDPBackend.EFFICIENT_ATTENTION: {\n            \"enable_math\": False,\n            \"enable_flash\": False,\n            \"enable_mem_efficient\": True,\n        },\n    }\n\n    from torch.profiler import ProfilerActivity, profile, record_function\n\n    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]\n\n    print(\n        f\"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds\"\n    )\n    with profile(\n        activities=activities, record_shapes=False, profile_memory=True\n    ) as prof:\n        with record_function(\"Default detailed stats\"):\n            for _ in range(25):\n                o = F.scaled_dot_product_attention(query, key, value)\n    print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n    with sdp_kernel(**backend_map[SDPBackend.MATH]):\n        print(\n            f\"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds\"\n        )\n        with profile(\n            activities=activities, record_shapes=False, profile_memory=True\n        ) as prof:\n            with record_function(\"Math implementation stats\"):\n                for _ in range(25):\n                    o = F.scaled_dot_product_attention(query, key, value)\n        print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):\n        try:\n            print(\n                f\"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds\"\n            )\n        except RuntimeError:\n            print(\"FlashAttention is not supported. 
See warnings for reasons.\")\n        with profile(\n            activities=activities, record_shapes=False, profile_memory=True\n        ) as prof:\n            with record_function(\"FlashAttention stats\"):\n                for _ in range(25):\n                    o = F.scaled_dot_product_attention(query, key, value)\n        print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):\n        try:\n            print(\n                f\"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds\"\n            )\n        except RuntimeError:\n            print(\"EfficientAttention is not supported. See warnings for reasons.\")\n        with profile(\n            activities=activities, record_shapes=False, profile_memory=True\n        ) as prof:\n            with record_function(\"EfficientAttention stats\"):\n                for _ in range(25):\n                    o = F.scaled_dot_product_attention(query, key, value)\n        print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n\ndef run_model(model, x, context):\n    return model(x, context)\n\n\ndef benchmark_transformer_blocks():\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    import torch.utils.benchmark as benchmark\n\n    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):\n        t0 = benchmark.Timer(\n            stmt=\"f(*args, **kwargs)\", globals={\"args\": args, \"kwargs\": kwargs, \"f\": f}\n        )\n        return t0.blocked_autorange().mean * 1e6\n\n    checkpoint = True\n    compile = False\n\n    batch_size = 32\n    h, w = 64, 64\n    context_len = 77\n    embed_dimension = 1024\n    context_dim = 1024\n    d_head = 64\n\n    transformer_depth = 4\n\n    n_heads = embed_dimension // d_head\n\n    dtype = torch.float16\n\n    model_native = 
SpatialTransformer(\n        embed_dimension,\n        n_heads,\n        d_head,\n        context_dim=context_dim,\n        use_linear=True,\n        use_checkpoint=checkpoint,\n        attn_type=\"softmax\",\n        depth=transformer_depth,\n        sdp_backend=SDPBackend.FLASH_ATTENTION,\n    ).to(device)\n    model_efficient_attn = SpatialTransformer(\n        embed_dimension,\n        n_heads,\n        d_head,\n        context_dim=context_dim,\n        use_linear=True,\n        depth=transformer_depth,\n        use_checkpoint=checkpoint,\n        attn_type=\"softmax-xformers\",\n    ).to(device)\n    if not checkpoint and compile:\n        print(\"compiling models\")\n        model_native = torch.compile(model_native)\n        model_efficient_attn = torch.compile(model_efficient_attn)\n\n    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)\n    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)\n\n    from torch.profiler import ProfilerActivity, profile, record_function\n\n    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]\n\n    with torch.autocast(\"cuda\"):\n        print(\n            f\"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds\"\n        )\n        print(\n            f\"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds\"\n        )\n\n        print(75 * \"+\")\n        print(\"NATIVE\")\n        print(75 * \"+\")\n        torch.cuda.reset_peak_memory_stats()\n        with profile(\n            activities=activities, record_shapes=False, profile_memory=True\n        ) as prof:\n            with record_function(\"NativeAttention stats\"):\n                for _ in range(25):\n                    model_native(x, c)\n        print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n        
print(torch.cuda.max_memory_allocated() * 1e-9, \"GB used by native block\")\n\n        print(75 * \"+\")\n        print(\"Xformers\")\n        print(75 * \"+\")\n        torch.cuda.reset_peak_memory_stats()\n        with profile(\n            activities=activities, record_shapes=False, profile_memory=True\n        ) as prof:\n            with record_function(\"xformers stats\"):\n                for _ in range(25):\n                    model_efficient_attn(x, c)\n        print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n        print(torch.cuda.max_memory_allocated() * 1e-9, \"GB used by xformers block\")\n\n\ndef test01():\n    # conv1x1 vs linear\n    from sgm.util import count_params\n\n    conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()\n    print(count_params(conv))\n    linear = torch.nn.Linear(3, 32).cuda()\n    print(count_params(linear))\n\n    print(conv.weight.shape)\n\n    # use same initialization\n    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))\n    linear.bias = torch.nn.Parameter(conv.bias)\n\n    print(linear.weight.shape)\n\n    x = torch.randn(11, 3, 64, 64).cuda()\n\n    xr = einops.rearrange(x, \"b c h w -> b (h w) c\").contiguous()\n    print(xr.shape)\n    out_linear = linear(xr)\n    print(out_linear.mean(), out_linear.shape)\n\n    out_conv = conv(x)\n    print(out_conv.mean(), out_conv.shape)\n    print(\"done with test01.\\n\")\n\n\ndef test02():\n    # try cosine flash attention\n    import time\n\n    torch.backends.cuda.matmul.allow_tf32 = True\n    torch.backends.cudnn.allow_tf32 = True\n    torch.backends.cudnn.benchmark = True\n    print(\"testing cosine flash attention...\")\n    DIM = 1024\n    SEQLEN = 4096\n    BS = 16\n\n    print(\" softmax (vanilla) first...\")\n    model = BasicTransformerBlock(\n        dim=DIM,\n        n_heads=16,\n        d_head=64,\n        dropout=0.0,\n        context_dim=None,\n        attn_mode=\"softmax\",\n    ).cuda()\n    try:\n        
x = torch.randn(BS, SEQLEN, DIM).cuda()\n        tic = time.time()\n        y = model(x)\n        toc = time.time()\n        print(y.shape, toc - tic)\n    except RuntimeError as e:\n        # likely oom\n        print(str(e))\n\n    print(\"\\n now flash-cosine...\")\n    model = BasicTransformerBlock(\n        dim=DIM,\n        n_heads=16,\n        d_head=64,\n        dropout=0.0,\n        context_dim=None,\n        attn_mode=\"flash-cosine\",\n    ).cuda()\n    x = torch.randn(BS, SEQLEN, DIM).cuda()\n    tic = time.time()\n    y = model(x)\n    toc = time.time()\n    print(y.shape, toc - tic)\n    print(\"done with test02.\\n\")\n\n\nif __name__ == \"__main__\":\n    # test01()\n    # test02()\n    # test03()\n\n    # benchmark_attn()\n    benchmark_transformer_blocks()\n\n    print(\"done.\")\n"
  },
  {
    "path": "scripts/util/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/util/detection/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/util/detection/nsfw_and_watermark_dectection.py",
    "content": "import os\n\nimport clip\nimport numpy as np\nimport torch\nimport torchvision.transforms as T\nfrom PIL import Image\n\nRESOURCES_ROOT = \"scripts/util/detection/\"\n\n\ndef predict_proba(X, weights, biases):\n    logits = X @ weights.T + biases\n    proba = np.where(\n        logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))\n    )\n    return proba.T\n\n\ndef load_model_weights(path: str):\n    model_weights = np.load(path)\n    return model_weights[\"weights\"], model_weights[\"biases\"]\n\n\ndef clip_process_images(images: torch.Tensor) -> torch.Tensor:\n    min_size = min(images.shape[-2:])\n    return T.Compose(\n        [\n            T.CenterCrop(min_size),  # TODO: this might affect the watermark, check this\n            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),\n            T.Normalize(\n                (0.48145466, 0.4578275, 0.40821073),\n                (0.26862954, 0.26130258, 0.27577711),\n            ),\n        ]\n    )(images)\n\n\nclass DeepFloydDataFiltering(object):\n    def __init__(\n        self, verbose: bool = False, device: torch.device = torch.device(\"cpu\")\n    ):\n        super().__init__()\n        self.verbose = verbose\n        self._device = None\n        self.clip_model, _ = clip.load(\"ViT-L/14\", device=device)\n        self.clip_model.eval()\n\n        self.cpu_w_weights, self.cpu_w_biases = load_model_weights(\n            os.path.join(RESOURCES_ROOT, \"w_head_v1.npz\")\n        )\n        self.cpu_p_weights, self.cpu_p_biases = load_model_weights(\n            os.path.join(RESOURCES_ROOT, \"p_head_v1.npz\")\n        )\n        self.w_threshold, self.p_threshold = 0.5, 0.5\n\n    @torch.inference_mode()\n    def __call__(self, images: torch.Tensor) -> torch.Tensor:\n        imgs = clip_process_images(images)\n        if self._device is None:\n            self._device = next(p for p in self.clip_model.parameters()).device\n        image_features = 
self.clip_model.encode_image(imgs.to(self._device))\n        image_features = image_features.detach().cpu().numpy().astype(np.float16)\n        p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)\n        w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)\n        if self.verbose:\n            print(f\"p_pred = {p_pred}, w_pred = {w_pred}\")\n        query = p_pred > self.p_threshold\n        if query.sum() > 0:\n            if self.verbose:\n                print(f\"Hit for p_threshold: {p_pred}\")\n            images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])\n        query = w_pred > self.w_threshold\n        if query.sum() > 0:\n            if self.verbose:\n                print(f\"Hit for w_threshold: {w_pred}\")\n            images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])\n        return images\n\n\ndef load_img(path: str) -> torch.Tensor:\n    image = Image.open(path)\n    if image.mode != \"RGB\":\n        image = image.convert(\"RGB\")\n    image_transforms = T.Compose(\n        [\n            T.ToTensor(),\n        ]\n    )\n    return image_transforms(image)[None, ...]\n\n\ndef test(root):\n    from einops import rearrange\n\n    data_filter = DeepFloydDataFiltering(verbose=True)\n    for p in os.listdir(root):\n        print(f\"running on {p}...\")\n        img = load_img(os.path.join(root, p))\n        filtered_img = data_filter(img)\n        filtered_img = rearrange(\n            255.0 * (filtered_img.numpy())[0], \"c h w -> h w c\"\n        ).astype(np.uint8)\n        Image.fromarray(filtered_img).save(\n            os.path.join(root, f\"{os.path.splitext(p)[0]}-filtered.jpg\")\n        )\n\n\nif __name__ == \"__main__\":\n    import fire\n\n    fire.Fire(test)\n    print(\"done.\")\n"
  },
  {
    "path": "sgm/__init__.py",
    "content": "from .models import AutoencodingEngine, DiffusionEngine\nfrom .util import get_configs_path, instantiate_from_config\n\n__version__ = \"0.1.0\"\n"
  },
  {
    "path": "sgm/data/__init__.py",
    "content": "from .dataset import StableDataModuleFromConfig\n"
  },
  {
    "path": "sgm/data/cifar10.py",
    "content": "import pytorch_lightning as pl\nimport torchvision\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\n\n\nclass CIFAR10DataDictWrapper(Dataset):\n    def __init__(self, dset):\n        super().__init__()\n        self.dset = dset\n\n    def __getitem__(self, i):\n        x, y = self.dset[i]\n        return {\"jpg\": x, \"cls\": y}\n\n    def __len__(self):\n        return len(self.dset)\n\n\nclass CIFAR10Loader(pl.LightningDataModule):\n    def __init__(self, batch_size, num_workers=0, shuffle=True):\n        super().__init__()\n\n        transform = transforms.Compose(\n            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]\n        )\n\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.shuffle = shuffle\n        self.train_dataset = CIFAR10DataDictWrapper(\n            torchvision.datasets.CIFAR10(\n                root=\".data/\", train=True, download=True, transform=transform\n            )\n        )\n        self.test_dataset = CIFAR10DataDictWrapper(\n            torchvision.datasets.CIFAR10(\n                root=\".data/\", train=False, download=True, transform=transform\n            )\n        )\n\n    def prepare_data(self):\n        pass\n\n    def train_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n        )\n\n    def test_dataloader(self):\n        return DataLoader(\n            self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n        )\n\n    def val_dataloader(self):\n        return DataLoader(\n            self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n        )\n"
  },
  {
    "path": "sgm/data/dataset.py",
    "content": "from typing import Optional\n\nimport torchdata.datapipes.iter\nimport webdataset as wds\nfrom omegaconf import DictConfig\nfrom pytorch_lightning import LightningDataModule\n\ntry:\n    from sdata import create_dataset, create_dummy_dataset, create_loader\nexcept ImportError as e:\n    print(\"#\" * 100)\n    print(\"Datasets not yet available\")\n    print(\"to enable, we need to add stable-datasets as a submodule\")\n    print(\"please use ``git submodule update --init --recursive``\")\n    print(\"and do ``pip install -e stable-datasets/`` from the root of this repo\")\n    print(\"#\" * 100)\n    exit(1)\n\n\nclass StableDataModuleFromConfig(LightningDataModule):\n    def __init__(\n        self,\n        train: DictConfig,\n        validation: Optional[DictConfig] = None,\n        test: Optional[DictConfig] = None,\n        skip_val_loader: bool = False,\n        dummy: bool = False,\n    ):\n        super().__init__()\n        self.train_config = train\n        assert (\n            \"datapipeline\" in self.train_config and \"loader\" in self.train_config\n        ), \"train config requires the fields `datapipeline` and `loader`\"\n\n        self.val_config = validation\n        if not skip_val_loader:\n            if self.val_config is not None:\n                assert (\n                    \"datapipeline\" in self.val_config and \"loader\" in self.val_config\n                ), \"validation config requires the fields `datapipeline` and `loader`\"\n            else:\n                print(\n                    \"Warning: No Validation datapipeline defined, using that one from training\"\n                )\n                self.val_config = train\n\n        self.test_config = test\n        if self.test_config is not None:\n            assert (\n                \"datapipeline\" in self.test_config and \"loader\" in self.test_config\n            ), \"test config requires the fields `datapipeline` and `loader`\"\n\n        self.dummy = dummy\n  
      if self.dummy:\n            print(\"#\" * 100)\n            print(\"USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)\")\n            print(\"#\" * 100)\n\n    def setup(self, stage: str) -> None:\n        print(\"Preparing datasets\")\n        if self.dummy:\n            data_fn = create_dummy_dataset\n        else:\n            data_fn = create_dataset\n\n        self.train_datapipeline = data_fn(**self.train_config.datapipeline)\n        if self.val_config:\n            self.val_datapipeline = data_fn(**self.val_config.datapipeline)\n        if self.test_config:\n            self.test_datapipeline = data_fn(**self.test_config.datapipeline)\n\n    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:\n        loader = create_loader(self.train_datapipeline, **self.train_config.loader)\n        return loader\n\n    def val_dataloader(self) -> wds.DataPipeline:\n        return create_loader(self.val_datapipeline, **self.val_config.loader)\n\n    def test_dataloader(self) -> wds.DataPipeline:\n        return create_loader(self.test_datapipeline, **self.test_config.loader)\n"
  },
  {
    "path": "sgm/data/mnist.py",
    "content": "import pytorch_lightning as pl\nimport torchvision\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\n\n\nclass MNISTDataDictWrapper(Dataset):\n    def __init__(self, dset):\n        super().__init__()\n        self.dset = dset\n\n    def __getitem__(self, i):\n        x, y = self.dset[i]\n        return {\"jpg\": x, \"cls\": y}\n\n    def __len__(self):\n        return len(self.dset)\n\n\nclass MNISTLoader(pl.LightningDataModule):\n    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):\n        super().__init__()\n\n        transform = transforms.Compose(\n            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]\n        )\n\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        # With num_workers == 0, PyTorch >= 2.0 requires prefetch_factor=None\n        self.prefetch_factor = prefetch_factor if num_workers > 0 else None\n        self.shuffle = shuffle\n        self.train_dataset = MNISTDataDictWrapper(\n            torchvision.datasets.MNIST(\n                root=\".data/\", train=True, download=True, transform=transform\n            )\n        )\n        self.test_dataset = MNISTDataDictWrapper(\n            torchvision.datasets.MNIST(\n                root=\".data/\", train=False, download=True, transform=transform\n            )\n        )\n\n    def prepare_data(self):\n        pass\n\n    def train_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n            prefetch_factor=self.prefetch_factor,\n        )\n\n    def test_dataloader(self):\n        return DataLoader(\n            self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n            prefetch_factor=self.prefetch_factor,\n        )\n\n    def val_dataloader(self):\n        return DataLoader(\n            
self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n            prefetch_factor=self.prefetch_factor,\n        )\n\n\nif __name__ == \"__main__\":\n    dset = MNISTDataDictWrapper(\n        torchvision.datasets.MNIST(\n            root=\".data/\",\n            train=False,\n            download=True,\n            transform=transforms.Compose(\n                [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]\n            ),\n        )\n    )\n    ex = dset[0]\n"
  },
  {
    "path": "sgm/inference/api.py",
    "content": "import pathlib\r\nfrom dataclasses import asdict, dataclass\r\nfrom enum import Enum\r\nfrom typing import Optional\r\n\r\nfrom omegaconf import OmegaConf\r\n\r\nfrom sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,\r\n                                   do_sample)\r\nfrom sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,\r\n                                                   DPMPP2SAncestralSampler,\r\n                                                   EulerAncestralSampler,\r\n                                                   EulerEDMSampler,\r\n                                                   HeunEDMSampler,\r\n                                                   LinearMultistepSampler)\r\nfrom sgm.util import load_model_from_config\r\n\r\n\r\nclass ModelArchitecture(str, Enum):\r\n    SDXL_V0_9_BASE = \"stable-diffusion-xl-v0-9-base\"\r\n    SDXL_V0_9_REFINER = \"stable-diffusion-xl-v0-9-refiner\"\r\n    SDXL_V1_BASE = \"stable-diffusion-xl-v1-base\"\r\n    SDXL_V1_REFINER = \"stable-diffusion-xl-v1-refiner\"\r\n\r\n\r\nclass Sampler(str, Enum):\r\n    EULER_EDM = \"EulerEDMSampler\"\r\n    HEUN_EDM = \"HeunEDMSampler\"\r\n    EULER_ANCESTRAL = \"EulerAncestralSampler\"\r\n    DPMPP2S_ANCESTRAL = \"DPMPP2SAncestralSampler\"\r\n    DPMPP2M = \"DPMPP2MSampler\"\r\n    LINEAR_MULTISTEP = \"LinearMultistepSampler\"\r\n\r\n\r\nclass Discretization(str, Enum):\r\n    LEGACY_DDPM = \"LegacyDDPMDiscretization\"\r\n    EDM = \"EDMDiscretization\"\r\n\r\n\r\nclass Guider(str, Enum):\r\n    VANILLA = \"VanillaCFG\"\r\n    IDENTITY = \"IdentityGuider\"\r\n\r\n\r\nclass Thresholder(str, Enum):\r\n    NONE = \"None\"\r\n\r\n\r\n@dataclass\r\nclass SamplingParams:\r\n    width: int = 1024\r\n    height: int = 1024\r\n    steps: int = 50\r\n    sampler: Sampler = Sampler.DPMPP2M\r\n    discretization: Discretization = Discretization.LEGACY_DDPM\r\n    guider: Guider = Guider.VANILLA\r\n    thresholder: Thresholder = 
Thresholder.NONE\r\n    scale: float = 6.0\r\n    aesthetic_score: float = 5.0\r\n    negative_aesthetic_score: float = 5.0\r\n    img2img_strength: float = 1.0\r\n    orig_width: int = 1024\r\n    orig_height: int = 1024\r\n    crop_coords_top: int = 0\r\n    crop_coords_left: int = 0\r\n    sigma_min: float = 0.0292\r\n    sigma_max: float = 14.6146\r\n    rho: float = 3.0\r\n    s_churn: float = 0.0\r\n    s_tmin: float = 0.0\r\n    s_tmax: float = 999.0\r\n    s_noise: float = 1.0\r\n    eta: float = 1.0\r\n    order: int = 4\r\n\r\n\r\n@dataclass\r\nclass SamplingSpec:\r\n    width: int\r\n    height: int\r\n    channels: int\r\n    factor: int\r\n    is_legacy: bool\r\n    config: str\r\n    ckpt: str\r\n    is_guided: bool\r\n\r\n\r\nmodel_specs = {\r\n    ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(\r\n        height=1024,\r\n        width=1024,\r\n        channels=4,\r\n        factor=8,\r\n        is_legacy=False,\r\n        config=\"sd_xl_base.yaml\",\r\n        ckpt=\"sd_xl_base_0.9.safetensors\",\r\n        is_guided=True,\r\n    ),\r\n    ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(\r\n        height=1024,\r\n        width=1024,\r\n        channels=4,\r\n        factor=8,\r\n        is_legacy=True,\r\n        config=\"sd_xl_refiner.yaml\",\r\n        ckpt=\"sd_xl_refiner_0.9.safetensors\",\r\n        is_guided=True,\r\n    ),\r\n    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(\r\n        height=1024,\r\n        width=1024,\r\n        channels=4,\r\n        factor=8,\r\n        is_legacy=False,\r\n        config=\"sd_xl_base.yaml\",\r\n        ckpt=\"sd_xl_base_1.0.safetensors\",\r\n        is_guided=True,\r\n    ),\r\n    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(\r\n        height=1024,\r\n        width=1024,\r\n        channels=4,\r\n        factor=8,\r\n        is_legacy=True,\r\n        config=\"sd_xl_refiner.yaml\",\r\n        ckpt=\"sd_xl_refiner_1.0.safetensors\",\r\n        is_guided=True,\r\n    ),\r\n}\r\n\r\n\r\nclass 
SamplingPipeline:\r\n    def __init__(\r\n        self,\r\n        model_id: ModelArchitecture,\r\n        model_path=\"checkpoints\",\r\n        config_path=\"configs/inference\",\r\n        device=\"cuda\",\r\n        use_fp16=True,\r\n    ) -> None:\r\n        if model_id not in model_specs:\r\n            raise ValueError(f\"Model {model_id} not supported\")\r\n        self.model_id = model_id\r\n        self.specs = model_specs[self.model_id]\r\n        self.config = str(pathlib.Path(config_path, self.specs.config))\r\n        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))\r\n        self.device = device\r\n        self.model = self._load_model(device=device, use_fp16=use_fp16)\r\n\r\n    def _load_model(self, device=\"cuda\", use_fp16=True):\r\n        config = OmegaConf.load(self.config)\r\n        model = load_model_from_config(config, self.ckpt)\r\n        if model is None:\r\n            raise ValueError(f\"Model {self.model_id} could not be loaded\")\r\n        model.to(device)\r\n        if use_fp16:\r\n            model.conditioner.half()\r\n            model.model.half()\r\n        return model\r\n\r\n    def text_to_image(\r\n        self,\r\n        params: SamplingParams,\r\n        prompt: str,\r\n        negative_prompt: str = \"\",\r\n        samples: int = 1,\r\n        return_latents: bool = False,\r\n    ):\r\n        sampler = get_sampler_config(params)\r\n        value_dict = asdict(params)\r\n        value_dict[\"prompt\"] = prompt\r\n        value_dict[\"negative_prompt\"] = negative_prompt\r\n        value_dict[\"target_width\"] = params.width\r\n        value_dict[\"target_height\"] = params.height\r\n        return do_sample(\r\n            self.model,\r\n            sampler,\r\n            value_dict,\r\n            samples,\r\n            params.height,\r\n            params.width,\r\n            self.specs.channels,\r\n            self.specs.factor,\r\n            force_uc_zero_embeddings=[\"txt\"] if not 
self.specs.is_legacy else [],\r\n            return_latents=return_latents,\r\n            filter=None,\r\n        )\r\n\r\n    def image_to_image(\r\n        self,\r\n        params: SamplingParams,\r\n        image,\r\n        prompt: str,\r\n        negative_prompt: str = \"\",\r\n        samples: int = 1,\r\n        return_latents: bool = False,\r\n    ):\r\n        sampler = get_sampler_config(params)\r\n\r\n        if params.img2img_strength < 1.0:\r\n            sampler.discretization = Img2ImgDiscretizationWrapper(\r\n                sampler.discretization,\r\n                strength=params.img2img_strength,\r\n            )\r\n        height, width = image.shape[2], image.shape[3]\r\n        value_dict = asdict(params)\r\n        value_dict[\"prompt\"] = prompt\r\n        value_dict[\"negative_prompt\"] = negative_prompt\r\n        value_dict[\"target_width\"] = width\r\n        value_dict[\"target_height\"] = height\r\n        return do_img2img(\r\n            image,\r\n            self.model,\r\n            sampler,\r\n            value_dict,\r\n            samples,\r\n            force_uc_zero_embeddings=[\"txt\"] if not self.specs.is_legacy else [],\r\n            return_latents=return_latents,\r\n            filter=None,\r\n        )\r\n\r\n    def refiner(\r\n        self,\r\n        params: SamplingParams,\r\n        image,\r\n        prompt: str,\r\n        negative_prompt: Optional[str] = None,\r\n        samples: int = 1,\r\n        return_latents: bool = False,\r\n    ):\r\n        sampler = get_sampler_config(params)\r\n        value_dict = {\r\n            \"orig_width\": image.shape[3] * 8,\r\n            \"orig_height\": image.shape[2] * 8,\r\n            \"target_width\": image.shape[3] * 8,\r\n            \"target_height\": image.shape[2] * 8,\r\n            \"prompt\": prompt,\r\n            \"negative_prompt\": negative_prompt,\r\n            \"crop_coords_top\": 0,\r\n            \"crop_coords_left\": 0,\r\n            
\"aesthetic_score\": 6.0,\r\n            \"negative_aesthetic_score\": 2.5,\r\n        }\r\n\r\n        return do_img2img(\r\n            image,\r\n            self.model,\r\n            sampler,\r\n            value_dict,\r\n            samples,\r\n            skip_encode=True,\r\n            return_latents=return_latents,\r\n            filter=None,\r\n        )\r\n\r\n\r\ndef get_guider_config(params: SamplingParams):\r\n    if params.guider == Guider.IDENTITY:\r\n        guider_config = {\r\n            \"target\": \"sgm.modules.diffusionmodules.guiders.IdentityGuider\"\r\n        }\r\n    elif params.guider == Guider.VANILLA:\r\n        scale = params.scale\r\n\r\n        thresholder = params.thresholder\r\n\r\n        if thresholder == Thresholder.NONE:\r\n            dyn_thresh_config = {\r\n                \"target\": \"sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding\"\r\n            }\r\n        else:\r\n            raise NotImplementedError\r\n\r\n        guider_config = {\r\n            \"target\": \"sgm.modules.diffusionmodules.guiders.VanillaCFG\",\r\n            \"params\": {\"scale\": scale, \"dyn_thresh_config\": dyn_thresh_config},\r\n        }\r\n    else:\r\n        raise NotImplementedError\r\n    return guider_config\r\n\r\n\r\ndef get_discretization_config(params: SamplingParams):\r\n    if params.discretization == Discretization.LEGACY_DDPM:\r\n        discretization_config = {\r\n            \"target\": \"sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\",\r\n        }\r\n    elif params.discretization == Discretization.EDM:\r\n        discretization_config = {\r\n            \"target\": \"sgm.modules.diffusionmodules.discretizer.EDMDiscretization\",\r\n            \"params\": {\r\n                \"sigma_min\": params.sigma_min,\r\n                \"sigma_max\": params.sigma_max,\r\n                \"rho\": params.rho,\r\n            },\r\n        }\r\n    else:\r\n        raise ValueError(f\"unknown 
discretization {params.discretization}\")\r\n    return discretization_config\r\n\r\n\r\ndef get_sampler_config(params: SamplingParams):\r\n    discretization_config = get_discretization_config(params)\r\n    guider_config = get_guider_config(params)\r\n    sampler = None\r\n    if params.sampler == Sampler.EULER_EDM:\r\n        return EulerEDMSampler(\r\n            num_steps=params.steps,\r\n            discretization_config=discretization_config,\r\n            guider_config=guider_config,\r\n            s_churn=params.s_churn,\r\n            s_tmin=params.s_tmin,\r\n            s_tmax=params.s_tmax,\r\n            s_noise=params.s_noise,\r\n            verbose=True,\r\n        )\r\n    if params.sampler == Sampler.HEUN_EDM:\r\n        return HeunEDMSampler(\r\n            num_steps=params.steps,\r\n            discretization_config=discretization_config,\r\n            guider_config=guider_config,\r\n            s_churn=params.s_churn,\r\n            s_tmin=params.s_tmin,\r\n            s_tmax=params.s_tmax,\r\n            s_noise=params.s_noise,\r\n            verbose=True,\r\n        )\r\n    if params.sampler == Sampler.EULER_ANCESTRAL:\r\n        return EulerAncestralSampler(\r\n            num_steps=params.steps,\r\n            discretization_config=discretization_config,\r\n            guider_config=guider_config,\r\n            eta=params.eta,\r\n            s_noise=params.s_noise,\r\n            verbose=True,\r\n        )\r\n    if params.sampler == Sampler.DPMPP2S_ANCESTRAL:\r\n        return DPMPP2SAncestralSampler(\r\n            num_steps=params.steps,\r\n            discretization_config=discretization_config,\r\n            guider_config=guider_config,\r\n            eta=params.eta,\r\n            s_noise=params.s_noise,\r\n            verbose=True,\r\n        )\r\n    if params.sampler == Sampler.DPMPP2M:\r\n        return DPMPP2MSampler(\r\n            num_steps=params.steps,\r\n            discretization_config=discretization_config,\r\n        
    guider_config=guider_config,\r\n            verbose=True,\r\n        )\r\n    if params.sampler == Sampler.LINEAR_MULTISTEP:\r\n        return LinearMultistepSampler(\r\n            num_steps=params.steps,\r\n            discretization_config=discretization_config,\r\n            guider_config=guider_config,\r\n            order=params.order,\r\n            verbose=True,\r\n        )\r\n\r\n    raise ValueError(f\"unknown sampler {params.sampler}!\")\r\n"
  },
  {
    "path": "sgm/inference/helpers.py",
    "content": "import math\nimport os\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom imwatermark import WatermarkEncoder\nfrom omegaconf import ListConfig\nfrom PIL import Image\nfrom torch import autocast\n\nfrom sgm.util import append_dims\n\n\nclass WatermarkEmbedder:\n    def __init__(self, watermark):\n        self.watermark = watermark\n        self.num_bits = len(watermark)\n        self.encoder = WatermarkEncoder()\n        self.encoder.set_watermark(\"bits\", self.watermark)\n\n    def __call__(self, image: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Adds a predefined watermark to the input image\n\n        Args:\n            image: ([N,] B, RGB, H, W) in range [0, 1]\n\n        Returns:\n            same as input but watermarked\n        \"\"\"\n        squeeze = len(image.shape) == 4\n        if squeeze:\n            image = image[None, ...]\n        n = image.shape[0]\n        image_np = rearrange(\n            (255 * image).detach().cpu(), \"n b c h w -> (n b) h w c\"\n        ).numpy()[:, :, :, ::-1]\n        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]\n        # watermarking library expects input as cv2 BGR format\n        for k in range(image_np.shape[0]):\n            image_np[k] = self.encoder.encode(image_np[k], \"dwtDct\")\n        image = torch.from_numpy(\n            rearrange(image_np[:, :, :, ::-1], \"(n b) h w c -> n b c h w\", n=n)\n        ).to(image.device)\n        image = torch.clamp(image / 255, min=0.0, max=1.0)\n        if squeeze:\n            image = image[0]\n        return image\n\n\n# A fixed 48-bit message that was chosen at random\n# WATERMARK_MESSAGE = 0xB3EC907BB19E\nWATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110\n# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1\nWATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]\nembed_watermark = 
WatermarkEmbedder(WATERMARK_BITS)\n\n\ndef get_unique_embedder_keys_from_conditioner(conditioner):\n    return list({x.input_key for x in conditioner.embedders})\n\n\ndef perform_save_locally(save_path, samples):\n    os.makedirs(os.path.join(save_path), exist_ok=True)\n    base_count = len(os.listdir(os.path.join(save_path)))\n    samples = embed_watermark(samples)\n    for sample in samples:\n        sample = 255.0 * rearrange(sample.cpu().numpy(), \"c h w -> h w c\")\n        Image.fromarray(sample.astype(np.uint8)).save(\n            os.path.join(save_path, f\"{base_count:09}.png\")\n        )\n        base_count += 1\n\n\nclass Img2ImgDiscretizationWrapper:\n    \"\"\"\n    wraps a discretizer, and prunes the sigmas\n    params:\n        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)\n    \"\"\"\n\n    def __init__(self, discretization, strength: float = 1.0):\n        self.discretization = discretization\n        self.strength = strength\n        assert 0.0 <= self.strength <= 1.0\n\n    def __call__(self, *args, **kwargs):\n        # sigmas start large and then decrease\n        sigmas = self.discretization(*args, **kwargs)\n        print(\"sigmas after discretization, before pruning img2img: \", sigmas)\n        sigmas = torch.flip(sigmas, (0,))\n        prune_index = max(int(self.strength * len(sigmas)), 1)\n        sigmas = sigmas[:prune_index]\n        print(\"prune index:\", prune_index)\n        sigmas = torch.flip(sigmas, (0,))\n        print(\"sigmas after pruning: \", sigmas)\n        return sigmas\n\n\ndef do_sample(\n    model,\n    sampler,\n    value_dict,\n    num_samples,\n    H,\n    W,\n    C,\n    F,\n    force_uc_zero_embeddings: Optional[List] = None,\n    batch2model_input: Optional[List] = None,\n    return_latents=False,\n    filter=None,\n    device=\"cuda\",\n):\n    if force_uc_zero_embeddings is None:\n        force_uc_zero_embeddings = []\n    if batch2model_input is None:\n        batch2model_input = []\n\n    with torch.no_grad():\n        with autocast(device) as precision_scope:\n            with model.ema_scope():\n                num_samples = [num_samples]\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    num_samples,\n                )\n                for key in batch:\n                    if isinstance(batch[key], torch.Tensor):\n                        print(key, batch[key].shape)\n                    elif isinstance(batch[key], list):\n                        print(key, [len(l) for l in batch[key]])\n                    else:\n                        print(key, batch[key])\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=force_uc_zero_embeddings,\n                )\n\n                for k in c:\n                    if not k == \"crossattn\":\n                        c[k], uc[k] = map(\n                            lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)\n                        )\n\n                additional_model_inputs = {}\n                for k in batch2model_input:\n                    additional_model_inputs[k] = batch[k]\n\n                shape = (math.prod(num_samples), C, H // F, W // F)\n                randn = torch.randn(shape).to(device)\n\n                def denoiser(input, sigma, c):\n                    return model.denoiser(\n                        model.model, input, sigma, c, **additional_model_inputs\n                    )\n\n                samples_z = sampler(denoiser, randn, cond=c, uc=uc)\n                samples_x = model.decode_first_stage(samples_z)\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n\n                if filter is not None:\n                    samples = 
filter(samples)\n\n                if return_latents:\n                    return samples, samples_z\n                return samples\n\n\ndef get_batch(keys, value_dict, N: Union[List, ListConfig], device=\"cuda\"):\n    # Hardcoded demo setups; might undergo some changes in the future\n\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"txt\":\n            batch[\"txt\"] = (\n                np.repeat([value_dict[\"prompt\"]], repeats=math.prod(N))\n                .reshape(N)\n                .tolist()\n            )\n            batch_uc[\"txt\"] = (\n                np.repeat([value_dict[\"negative_prompt\"]], repeats=math.prod(N))\n                .reshape(N)\n                .tolist()\n            )\n        elif key == \"original_size_as_tuple\":\n            batch[\"original_size_as_tuple\"] = (\n                torch.tensor([value_dict[\"orig_height\"], value_dict[\"orig_width\"]])\n                .to(device)\n                .repeat(*N, 1)\n            )\n        elif key == \"crop_coords_top_left\":\n            batch[\"crop_coords_top_left\"] = (\n                torch.tensor(\n                    [value_dict[\"crop_coords_top\"], value_dict[\"crop_coords_left\"]]\n                )\n                .to(device)\n                .repeat(*N, 1)\n            )\n        elif key == \"aesthetic_score\":\n            batch[\"aesthetic_score\"] = (\n                torch.tensor([value_dict[\"aesthetic_score\"]]).to(device).repeat(*N, 1)\n            )\n            batch_uc[\"aesthetic_score\"] = (\n                torch.tensor([value_dict[\"negative_aesthetic_score\"]])\n                .to(device)\n                .repeat(*N, 1)\n            )\n\n        elif key == \"target_size_as_tuple\":\n            batch[\"target_size_as_tuple\"] = (\n                torch.tensor([value_dict[\"target_height\"], value_dict[\"target_width\"]])\n                .to(device)\n                .repeat(*N, 1)\n            )\n        else:\n         
   batch[key] = value_dict[key]\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n    return batch, batch_uc\n\n\ndef get_input_image_tensor(image: Image.Image, device=\"cuda\"):\n    w, h = image.size\n    print(f\"loaded input image of size ({w}, {h})\")\n    width, height = map(\n        lambda x: x - x % 64, (w, h)\n    )  # resize to integer multiple of 64\n    image = image.resize((width, height))\n    image_array = np.array(image.convert(\"RGB\"))\n    image_array = image_array[None].transpose(0, 3, 1, 2)\n    image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0\n    return image_tensor.to(device)\n\n\ndef do_img2img(\n    img,\n    model,\n    sampler,\n    value_dict,\n    num_samples,\n    force_uc_zero_embeddings=[],\n    additional_kwargs={},\n    offset_noise_level: float = 0.0,\n    return_latents=False,\n    skip_encode=False,\n    filter=None,\n    device=\"cuda\",\n):\n    with torch.no_grad():\n        with autocast(device) as precision_scope:\n            with model.ema_scope():\n                batch, batch_uc = get_batch(\n                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n                    value_dict,\n                    [num_samples],\n                )\n                c, uc = model.conditioner.get_unconditional_conditioning(\n                    batch,\n                    batch_uc=batch_uc,\n                    force_uc_zero_embeddings=force_uc_zero_embeddings,\n                )\n\n                for k in c:\n                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))\n\n                for k in additional_kwargs:\n                    c[k] = uc[k] = additional_kwargs[k]\n                if skip_encode:\n                    z = img\n                else:\n                    z = model.encode_first_stage(img)\n                noise = 
torch.randn_like(z)\n                sigmas = sampler.discretization(sampler.num_steps)\n                sigma = sigmas[0].to(z.device)\n\n                if offset_noise_level > 0.0:\n                    noise = noise + offset_noise_level * append_dims(\n                        torch.randn(z.shape[0], device=z.device), z.ndim\n                    )\n                noised_z = z + noise * append_dims(sigma, z.ndim)\n                noised_z = noised_z / torch.sqrt(\n                    1.0 + sigmas[0] ** 2.0\n                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.\n\n                def denoiser(x, sigma, c):\n                    return model.denoiser(model.model, x, sigma, c)\n\n                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)\n                samples_x = model.decode_first_stage(samples_z)\n                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n\n                if filter is not None:\n                    samples = filter(samples)\n\n                if return_latents:\n                    return samples, samples_z\n                return samples\n"
  },
  {
    "path": "sgm/lr_scheduler.py",
    "content": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n\n    def __init__(\n        self,\n        warm_up_steps,\n        lr_min,\n        lr_max,\n        lr_start,\n        max_decay_steps,\n        verbosity_interval=0,\n    ):\n        self.lr_warm_up_steps = warm_up_steps\n        self.lr_start = lr_start\n        self.lr_min = lr_min\n        self.lr_max = lr_max\n        self.lr_max_decay_steps = max_decay_steps\n        self.last_lr = 0.0\n        self.verbosity_interval = verbosity_interval\n\n    def schedule(self, n, **kwargs):\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(f\"current step: {n}, recent lr-multiplier: {self.last_lr}\")\n        if n < self.lr_warm_up_steps:\n            lr = (\n                self.lr_max - self.lr_start\n            ) / self.lr_warm_up_steps * n + self.lr_start\n            self.last_lr = lr\n            return lr\n        else:\n            t = (n - self.lr_warm_up_steps) / (\n                self.lr_max_decay_steps - self.lr_warm_up_steps\n            )\n            t = min(t, 1.0)\n            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (\n                1 + np.cos(t * np.pi)\n            )\n            self.last_lr = lr\n            return lr\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n\n\nclass LambdaWarmUpCosineScheduler2:\n    \"\"\"\n    supports repeated iterations, configurable via lists\n    note: use with a base_lr of 1.0.\n    \"\"\"\n\n    def __init__(\n        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0\n    ):\n        assert (\n            len(warm_up_steps)\n            == len(f_min)\n            == len(f_max)\n            == len(f_start)\n            == len(cycle_lengths)\n        )\n        self.lr_warm_up_steps = warm_up_steps\n        self.f_start = f_start\n       
 self.f_min = f_min\n        self.f_max = f_max\n        self.cycle_lengths = cycle_lengths\n        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))\n        self.last_f = 0.0\n        self.verbosity_interval = verbosity_interval\n\n    def find_in_interval(self, n):\n        interval = 0\n        for cl in self.cum_cycles[1:]:\n            if n <= cl:\n                return interval\n            interval += 1\n\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(\n                    f\"current step: {n}, recent lr-multiplier: {self.last_f}, \"\n                    f\"current cycle {cycle}\"\n                )\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[\n                cycle\n            ] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            t = (n - self.lr_warm_up_steps[cycle]) / (\n                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]\n            )\n            t = min(t, 1.0)\n            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (\n                1 + np.cos(t * np.pi)\n            )\n            self.last_f = f\n            return f\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n\n\nclass LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(\n                    f\"current step: {n}, recent lr-multiplier: {self.last_f}, \"\n                    f\"current cycle {cycle}\"\n                )\n\n        if n < 
self.lr_warm_up_steps[cycle]:\n            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[\n                cycle\n            ] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (\n                self.cycle_lengths[cycle] - n\n            ) / (self.cycle_lengths[cycle])\n            self.last_f = f\n            return f\n"
  },
  {
    "path": "sgm/models/__init__.py",
    "content": "from .autoencoder import AutoencodingEngine\nfrom .diffusion import DiffusionEngine\n"
  },
  {
    "path": "sgm/models/autoencoder.py",
    "content": "import logging\nimport math\nimport re\nfrom abc import abstractmethod\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom packaging import version\n\nfrom ..modules.autoencoding.regularizers import AbstractRegularizer\nfrom ..modules.ema import LitEma\nfrom ..util import (default, get_nested_attribute, get_obj_from_str,\n                    instantiate_from_config)\n\nlogpy = logging.getLogger(__name__)\n\n\nclass AbstractAutoencoder(pl.LightningModule):\n    \"\"\"\n    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,\n    unCLIP models, etc. Hence, it is fairly general, and specific features\n    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.\n    \"\"\"\n\n    def __init__(\n        self,\n        ema_decay: Union[None, float] = None,\n        monitor: Union[None, str] = None,\n        input_key: str = \"jpg\",\n    ):\n        super().__init__()\n\n        self.input_key = input_key\n        self.use_ema = ema_decay is not None\n        if monitor is not None:\n            self.monitor = monitor\n\n        if self.use_ema:\n            self.model_ema = LitEma(self, decay=ema_decay)\n            logpy.info(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        if version.parse(torch.__version__) >= version.parse(\"2.0.0\"):\n            self.automatic_optimization = False\n\n    def apply_ckpt(self, ckpt: Union[None, str, dict]):\n        if ckpt is None:\n            return\n        if isinstance(ckpt, str):\n            ckpt = {\n                \"target\": \"sgm.modules.checkpoint.CheckpointEngine\",\n                \"params\": {\"ckpt_path\": ckpt},\n            }\n        engine = instantiate_from_config(ckpt)\n        engine(self)\n\n    @abstractmethod\n    def 
get_input(self, batch) -> Any:\n        raise NotImplementedError()\n\n    def on_train_batch_end(self, *args, **kwargs):\n        # for EMA computation\n        if self.use_ema:\n            self.model_ema(self)\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.parameters())\n            self.model_ema.copy_to(self)\n            if context is not None:\n                logpy.info(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.parameters())\n                if context is not None:\n                    logpy.info(f\"{context}: Restored training weights\")\n\n    @abstractmethod\n    def encode(self, *args, **kwargs) -> torch.Tensor:\n        raise NotImplementedError(\"encode()-method of abstract base class called\")\n\n    @abstractmethod\n    def decode(self, *args, **kwargs) -> torch.Tensor:\n        raise NotImplementedError(\"decode()-method of abstract base class called\")\n\n    def instantiate_optimizer_from_config(self, params, lr, cfg):\n        logpy.info(f\"loading >>> {cfg['target']} <<< optimizer from config\")\n        return get_obj_from_str(cfg[\"target\"])(\n            params, lr=lr, **cfg.get(\"params\", dict())\n        )\n\n    def configure_optimizers(self) -> Any:\n        raise NotImplementedError()\n\n\nclass AutoencodingEngine(AbstractAutoencoder):\n    \"\"\"\n    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL\n    (we also restore them explicitly as special cases for legacy reasons).\n    Regularizations such as KL or VQ are moved to the regularizer class.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        encoder_config: Dict,\n        decoder_config: Dict,\n        loss_config: Dict,\n        regularizer_config: Dict,\n        optimizer_config: Union[Dict, None] = None,\n        
lr_g_factor: float = 1.0,\n        trainable_ae_params: Optional[List[List[str]]] = None,\n        ae_optimizer_args: Optional[List[dict]] = None,\n        trainable_disc_params: Optional[List[List[str]]] = None,\n        disc_optimizer_args: Optional[List[dict]] = None,\n        disc_start_iter: int = 0,\n        diff_boost_factor: float = 3.0,\n        ckpt_engine: Union[None, str, dict] = None,\n        ckpt_path: Optional[str] = None,\n        additional_decode_keys: Optional[List[str]] = None,\n        **kwargs,\n    ):\n        super().__init__(*args, **kwargs)\n        self.automatic_optimization = False  # pytorch lightning\n\n        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)\n        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)\n        self.loss: torch.nn.Module = instantiate_from_config(loss_config)\n        self.regularization: AbstractRegularizer = instantiate_from_config(\n            regularizer_config\n        )\n        self.optimizer_config = default(\n            optimizer_config, {\"target\": \"torch.optim.Adam\"}\n        )\n        self.diff_boost_factor = diff_boost_factor\n        self.disc_start_iter = disc_start_iter\n        self.lr_g_factor = lr_g_factor\n        self.trainable_ae_params = trainable_ae_params\n        if self.trainable_ae_params is not None:\n            self.ae_optimizer_args = default(\n                ae_optimizer_args,\n                [{} for _ in range(len(self.trainable_ae_params))],\n            )\n            assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)\n        else:\n            self.ae_optimizer_args = [{}]  # keeps the type consistent\n\n        self.trainable_disc_params = trainable_disc_params\n        if self.trainable_disc_params is not None:\n            self.disc_optimizer_args = default(\n                disc_optimizer_args,\n                [{} for _ in range(len(self.trainable_disc_params))],\n            )\n            assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)\n        else:\n            self.disc_optimizer_args = [{}]  # keeps the type consistent\n\n        if ckpt_path is not None:\n            assert ckpt_engine is None, \"Can't set ckpt_engine and ckpt_path\"\n            logpy.warning(\"Checkpoint path is deprecated, use `ckpt_engine` instead\")\n        self.apply_ckpt(default(ckpt_path, ckpt_engine))\n        self.additional_decode_keys = set(default(additional_decode_keys, []))\n\n    def get_input(self, batch: Dict) -> torch.Tensor:\n        # assuming unified data format, dataloader returns a dict.\n        # image tensors should be scaled to -1 ... 1 and in channels-first\n        # format (e.g., bchw instead of bhwc)\n        return batch[self.input_key]\n\n    def get_autoencoder_params(self) -> list:\n        params = []\n        if hasattr(self.loss, \"get_trainable_autoencoder_parameters\"):\n            params += list(self.loss.get_trainable_autoencoder_parameters())\n        if hasattr(self.regularization, \"get_trainable_parameters\"):\n            params += list(self.regularization.get_trainable_parameters())\n        params = params + list(self.encoder.parameters())\n        params = params + list(self.decoder.parameters())\n        return params\n\n    def get_discriminator_params(self) -> list:\n        if hasattr(self.loss, \"get_trainable_parameters\"):\n            params = list(self.loss.get_trainable_parameters())  # e.g., discriminator\n        else:\n            params = []\n        return params\n\n    def get_last_layer(self):\n        return self.decoder.get_last_layer()\n\n    def encode(\n        self,\n        x: torch.Tensor,\n        return_reg_log: bool = False,\n        unregularized: bool = False,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:\n        z = self.encoder(x)\n        if unregularized:\n            return z, dict()\n        z, reg_log = self.regularization(z)\n        if return_reg_log:\n 
           return z, reg_log\n        return z\n\n    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:\n        x = self.decoder(z, **kwargs)\n        return x\n\n    def forward(\n        self, x: torch.Tensor, **additional_decode_kwargs\n    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:\n        z, reg_log = self.encode(x, return_reg_log=True)\n        dec = self.decode(z, **additional_decode_kwargs)\n        return z, dec, reg_log\n\n    def inner_training_step(\n        self, batch: dict, batch_idx: int, optimizer_idx: int = 0\n    ) -> torch.Tensor:\n        x = self.get_input(batch)\n        additional_decode_kwargs = {\n            key: batch[key] for key in self.additional_decode_keys.intersection(batch)\n        }\n        z, xrec, regularization_log = self(x, **additional_decode_kwargs)\n        if hasattr(self.loss, \"forward_keys\"):\n            extra_info = {\n                \"z\": z,\n                \"optimizer_idx\": optimizer_idx,\n                \"global_step\": self.global_step,\n                \"last_layer\": self.get_last_layer(),\n                \"split\": \"train\",\n                \"regularization_log\": regularization_log,\n                \"autoencoder\": self,\n            }\n            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}\n        else:\n            extra_info = dict()\n\n        if optimizer_idx == 0:\n            # autoencode\n            out_loss = self.loss(x, xrec, **extra_info)\n            if isinstance(out_loss, tuple):\n                aeloss, log_dict_ae = out_loss\n            else:\n                # simple loss function\n                aeloss = out_loss\n                log_dict_ae = {\"train/loss/rec\": aeloss.detach()}\n\n            self.log_dict(\n                log_dict_ae,\n                prog_bar=False,\n                logger=True,\n                on_step=True,\n                on_epoch=True,\n                sync_dist=False,\n            )\n            self.log(\n    
            \"loss\",\n                aeloss.mean().detach(),\n                prog_bar=True,\n                logger=False,\n                on_epoch=False,\n                on_step=True,\n            )\n            return aeloss\n        elif optimizer_idx == 1:\n            # discriminator\n            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)\n            # -> discriminator always needs to return a tuple\n            self.log_dict(\n                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True\n            )\n            return discloss\n        else:\n            raise NotImplementedError(f\"Unknown optimizer {optimizer_idx}\")\n\n    def training_step(self, batch: dict, batch_idx: int):\n        opts = self.optimizers()\n        if not isinstance(opts, list):\n            # Non-adversarial case\n            opts = [opts]\n        optimizer_idx = batch_idx % len(opts)\n        if self.global_step < self.disc_start_iter:\n            optimizer_idx = 0\n        opt = opts[optimizer_idx]\n        opt.zero_grad()\n        with opt.toggle_model():\n            loss = self.inner_training_step(\n                batch, batch_idx, optimizer_idx=optimizer_idx\n            )\n            self.manual_backward(loss)\n        opt.step()\n\n    def validation_step(self, batch: dict, batch_idx: int) -> Dict:\n        log_dict = self._validation_step(batch, batch_idx)\n        with self.ema_scope():\n            log_dict_ema = self._validation_step(batch, batch_idx, postfix=\"_ema\")\n            log_dict.update(log_dict_ema)\n        return log_dict\n\n    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = \"\") -> Dict:\n        x = self.get_input(batch)\n\n        z, xrec, regularization_log = self(x)\n        if hasattr(self.loss, \"forward_keys\"):\n            extra_info = {\n                \"z\": z,\n                \"optimizer_idx\": 0,\n                \"global_step\": self.global_step,\n                
\"last_layer\": self.get_last_layer(),\n                \"split\": \"val\" + postfix,\n                \"regularization_log\": regularization_log,\n                \"autoencoder\": self,\n            }\n            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}\n        else:\n            extra_info = dict()\n        out_loss = self.loss(x, xrec, **extra_info)\n        if isinstance(out_loss, tuple):\n            aeloss, log_dict_ae = out_loss\n        else:\n            # simple loss function\n            aeloss = out_loss\n            log_dict_ae = {f\"val{postfix}/loss/rec\": aeloss.detach()}\n        full_log_dict = log_dict_ae\n\n        if \"optimizer_idx\" in extra_info:\n            extra_info[\"optimizer_idx\"] = 1\n            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)\n            full_log_dict.update(log_dict_disc)\n        self.log(\n            f\"val{postfix}/loss/rec\",\n            log_dict_ae[f\"val{postfix}/loss/rec\"],\n            sync_dist=True,\n        )\n        self.log_dict(full_log_dict, sync_dist=True)\n        return full_log_dict\n\n    def get_param_groups(\n        self, parameter_names: List[List[str]], optimizer_args: List[dict]\n    ) -> Tuple[List[Dict[str, Any]], int]:\n        groups = []\n        num_params = 0\n        for names, args in zip(parameter_names, optimizer_args):\n            params = []\n            for pattern_ in names:\n                pattern_params = []\n                pattern = re.compile(pattern_)\n                for p_name, param in self.named_parameters():\n                    if re.match(pattern, p_name):\n                        pattern_params.append(param)\n                        num_params += param.numel()\n                if len(pattern_params) == 0:\n                    logpy.warning(f\"Did not find parameters for pattern {pattern_}\")\n                params.extend(pattern_params)\n            groups.append({\"params\": params, **args})\n        return groups, 
num_params\n\n    def configure_optimizers(self) -> List[torch.optim.Optimizer]:\n        if self.trainable_ae_params is None:\n            ae_params = self.get_autoencoder_params()\n        else:\n            ae_params, num_ae_params = self.get_param_groups(\n                self.trainable_ae_params, self.ae_optimizer_args\n            )\n            logpy.info(f\"Number of trainable autoencoder parameters: {num_ae_params:,}\")\n        if self.trainable_disc_params is None:\n            disc_params = self.get_discriminator_params()\n        else:\n            disc_params, num_disc_params = self.get_param_groups(\n                self.trainable_disc_params, self.disc_optimizer_args\n            )\n            logpy.info(\n                f\"Number of trainable discriminator parameters: {num_disc_params:,}\"\n            )\n        opt_ae = self.instantiate_optimizer_from_config(\n            ae_params,\n            default(self.lr_g_factor, 1.0) * self.learning_rate,\n            self.optimizer_config,\n        )\n        opts = [opt_ae]\n        if len(disc_params) > 0:\n            opt_disc = self.instantiate_optimizer_from_config(\n                disc_params, self.learning_rate, self.optimizer_config\n            )\n            opts.append(opt_disc)\n\n        return opts\n\n    @torch.no_grad()\n    def log_images(\n        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs\n    ) -> dict:\n        log = dict()\n        additional_decode_kwargs = {}\n        x = self.get_input(batch)\n        additional_decode_kwargs.update(\n            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}\n        )\n\n        _, xrec, _ = self(x, **additional_decode_kwargs)\n        log[\"inputs\"] = x\n        log[\"reconstructions\"] = xrec\n        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)\n        diff.clamp_(0, 1.0)\n        log[\"diff\"] = 2.0 * diff - 1.0\n        # diff_boost shows location of small 
errors, by boosting their\n        # brightness.\n        log[\"diff_boost\"] = (\n            2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1\n        )\n        if hasattr(self.loss, \"log_images\"):\n            log.update(self.loss.log_images(x, xrec))\n        with self.ema_scope():\n            _, xrec_ema, _ = self(x, **additional_decode_kwargs)\n            log[\"reconstructions_ema\"] = xrec_ema\n            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)\n            diff_ema.clamp_(0, 1.0)\n            log[\"diff_ema\"] = 2.0 * diff_ema - 1.0\n            log[\"diff_boost_ema\"] = (\n                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1\n            )\n        if additional_log_kwargs:\n            additional_decode_kwargs.update(additional_log_kwargs)\n            _, xrec_add, _ = self(x, **additional_decode_kwargs)\n            log_str = \"reconstructions-\" + \"-\".join(\n                [f\"{key}={additional_log_kwargs[key]}\" for key in additional_log_kwargs]\n            )\n            log[log_str] = xrec_add\n        return log\n\n\nclass AutoencodingEngineLegacy(AutoencodingEngine):\n    def __init__(self, embed_dim: int, **kwargs):\n        self.max_batch_size = kwargs.pop(\"max_batch_size\", None)\n        ddconfig = kwargs.pop(\"ddconfig\")\n        ckpt_path = kwargs.pop(\"ckpt_path\", None)\n        ckpt_engine = kwargs.pop(\"ckpt_engine\", None)\n        super().__init__(\n            encoder_config={\n                \"target\": \"sgm.modules.diffusionmodules.model.Encoder\",\n                \"params\": ddconfig,\n            },\n            decoder_config={\n                \"target\": \"sgm.modules.diffusionmodules.model.Decoder\",\n                \"params\": ddconfig,\n            },\n            **kwargs,\n        )\n        self.quant_conv = torch.nn.Conv2d(\n            (1 + ddconfig[\"double_z\"]) * ddconfig[\"z_channels\"],\n            (1 + ddconfig[\"double_z\"]) * 
embed_dim,\n            1,\n        )\n        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig[\"z_channels\"], 1)\n        self.embed_dim = embed_dim\n\n        self.apply_ckpt(default(ckpt_path, ckpt_engine))\n\n    def get_autoencoder_params(self) -> list:\n        params = super().get_autoencoder_params()\n        return params\n\n    def encode(\n        self, x: torch.Tensor, return_reg_log: bool = False\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:\n        if self.max_batch_size is None:\n            z = self.encoder(x)\n            z = self.quant_conv(z)\n        else:\n            N = x.shape[0]\n            bs = self.max_batch_size\n            n_batches = int(math.ceil(N / bs))\n            z = list()\n            for i_batch in range(n_batches):\n                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])\n                z_batch = self.quant_conv(z_batch)\n                z.append(z_batch)\n            z = torch.cat(z, 0)\n\n        z, reg_log = self.regularization(z)\n        if return_reg_log:\n            return z, reg_log\n        return z\n\n    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:\n        if self.max_batch_size is None:\n            dec = self.post_quant_conv(z)\n            dec = self.decoder(dec, **decoder_kwargs)\n        else:\n            N = z.shape[0]\n            bs = self.max_batch_size\n            n_batches = int(math.ceil(N / bs))\n            dec = list()\n            for i_batch in range(n_batches):\n                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])\n                dec_batch = self.decoder(dec_batch, **decoder_kwargs)\n                dec.append(dec_batch)\n            dec = torch.cat(dec, 0)\n\n        return dec\n\n\nclass AutoencoderKL(AutoencodingEngineLegacy):\n    def __init__(self, **kwargs):\n        if \"lossconfig\" in kwargs:\n            kwargs[\"loss_config\"] = kwargs.pop(\"lossconfig\")\n        
super().__init__(\n            regularizer_config={\n                \"target\": (\n                    \"sgm.modules.autoencoding.regularizers\"\n                    \".DiagonalGaussianRegularizer\"\n                )\n            },\n            **kwargs,\n        )\n\n\nclass AutoencoderLegacyVQ(AutoencodingEngineLegacy):\n    def __init__(\n        self,\n        embed_dim: int,\n        n_embed: int,\n        sane_index_shape: bool = False,\n        **kwargs,\n    ):\n        if \"lossconfig\" in kwargs:\n            logpy.warning(\"Parameter `lossconfig` is deprecated, use `loss_config`.\")\n            kwargs[\"loss_config\"] = kwargs.pop(\"lossconfig\")\n        super().__init__(\n            regularizer_config={\n                \"target\": (\n                    \"sgm.modules.autoencoding.regularizers.quantize.VectorQuantizer\"\n                ),\n                \"params\": {\n                    \"n_e\": n_embed,\n                    \"e_dim\": embed_dim,\n                    \"sane_index_shape\": sane_index_shape,\n                },\n            },\n            **kwargs,\n        )\n\n\nclass IdentityFirstStage(AbstractAutoencoder):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def get_input(self, x: Any) -> Any:\n        return x\n\n    def encode(self, x: Any, *args, **kwargs) -> Any:\n        return x\n\n    def decode(self, x: Any, *args, **kwargs) -> Any:\n        return x\n\n\nclass AEIntegerWrapper(nn.Module):\n    def __init__(\n        self,\n        model: nn.Module,\n        shape: Union[None, Tuple[int, int], List[int]] = (16, 16),\n        regularization_key: str = \"regularization\",\n        encoder_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        super().__init__()\n        self.model = model\n        assert hasattr(model, \"encode\") and hasattr(\n            model, \"decode\"\n        ), \"Need AE interface\"\n        self.regularization = get_nested_attribute(model, 
regularization_key)\n        self.shape = shape\n        self.encoder_kwargs = default(encoder_kwargs, {\"return_reg_log\": True})\n\n    def encode(self, x) -> torch.Tensor:\n        assert (\n            not self.training\n        ), f\"{self.__class__.__name__} only supports inference currently\"\n        _, log = self.model.encode(x, **self.encoder_kwargs)\n        assert isinstance(log, dict)\n        inds = log[\"min_encoding_indices\"]\n        return rearrange(inds, \"b ... -> b (...)\")\n\n    def decode(\n        self, inds: torch.Tensor, shape: Union[None, tuple, list] = None\n    ) -> torch.Tensor:\n        # expect inds shape (b, s) with s = h*w\n        shape = default(shape, self.shape)  # Optional[(h, w)]\n        if shape is not None:\n            assert len(shape) == 2, f\"Unhandled shape {shape}\"\n            inds = rearrange(inds, \"b (h w) -> b h w\", h=shape[0], w=shape[1])\n        h = self.regularization.get_codebook_entry(inds)  # (b, h, w, c)\n        h = rearrange(h, \"b h w c -> b c h w\")\n        return self.model.decode(h)\n\n\nclass AutoencoderKLModeOnly(AutoencodingEngineLegacy):\n    def __init__(self, **kwargs):\n        if \"lossconfig\" in kwargs:\n            kwargs[\"loss_config\"] = kwargs.pop(\"lossconfig\")\n        super().__init__(\n            regularizer_config={\n                \"target\": (\n                    \"sgm.modules.autoencoding.regularizers\"\n                    \".DiagonalGaussianRegularizer\"\n                ),\n                \"params\": {\"sample\": False},\n            },\n            **kwargs,\n        )\n"
  },
  {
    "path": "sgm/models/diffusion.py",
    "content": "import math\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\nfrom torch.optim.lr_scheduler import LambdaLR\n\nfrom ..modules import UNCONDITIONAL_CONFIG\nfrom ..modules.autoencoding.temporal_ae import VideoDecoder\nfrom ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER\nfrom ..modules.ema import LitEma\nfrom ..util import (default, disabled_train, get_obj_from_str,\n                    instantiate_from_config, log_txt_as_img)\n\n\nclass DiffusionEngine(pl.LightningModule):\n    def __init__(\n        self,\n        network_config,\n        denoiser_config,\n        first_stage_config,\n        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        network_wrapper: Union[None, str] = None,\n        ckpt_path: Union[None, str] = None,\n        use_ema: bool = False,\n        ema_decay_rate: float = 0.9999,\n        scale_factor: float = 1.0,\n        disable_first_stage_autocast=False,\n        input_key: str = \"jpg\",\n        log_keys: Union[List, None] = None,\n        no_cond_log: bool = False,\n        compile_model: bool = False,\n        en_and_decode_n_samples_a_time: Optional[int] = None,\n    ):\n        super().__init__()\n        self.log_keys = log_keys\n        self.input_key = input_key\n        self.optimizer_config = default(\n            optimizer_config, {\"target\": \"torch.optim.AdamW\"}\n        )\n        model = instantiate_from_config(network_config)\n        self.model = 
get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(\n            model, compile_model=compile_model\n        )\n\n        self.denoiser = instantiate_from_config(denoiser_config)\n        self.sampler = (\n            instantiate_from_config(sampler_config)\n            if sampler_config is not None\n            else None\n        )\n        self.conditioner = instantiate_from_config(\n            default(conditioner_config, UNCONDITIONAL_CONFIG)\n        )\n        self.scheduler_config = scheduler_config\n        self._init_first_stage(first_stage_config)\n\n        self.loss_fn = (\n            instantiate_from_config(loss_fn_config)\n            if loss_fn_config is not None\n            else None\n        )\n\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self.model, decay=ema_decay_rate)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        self.scale_factor = scale_factor\n        self.disable_first_stage_autocast = disable_first_stage_autocast\n        self.no_cond_log = no_cond_log\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path)\n\n        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time\n\n    def init_from_ckpt(\n        self,\n        path: str,\n    ) -> None:\n        if path.endswith(\"ckpt\"):\n            sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        elif path.endswith(\"safetensors\"):\n            sd = load_safetensors(path)\n        else:\n            raise NotImplementedError\n\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(\n            f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def _init_first_stage(self, 
config):\n        model = instantiate_from_config(config).eval()\n        model.train = disabled_train\n        for param in model.parameters():\n            param.requires_grad = False\n        self.first_stage_model = model\n\n    def get_input(self, batch):\n        # assuming unified data format, dataloader returns a dict.\n        # image tensors should be scaled to -1 ... 1 and in bchw format\n        return batch[self.input_key]\n\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        z = 1.0 / self.scale_factor * z\n        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])\n\n        n_rounds = math.ceil(z.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                if isinstance(self.first_stage_model.decoder, VideoDecoder):\n                    kwargs = {\"timesteps\": len(z[n * n_samples : (n + 1) * n_samples])}\n                else:\n                    kwargs = {}\n                out = self.first_stage_model.decode(\n                    z[n * n_samples : (n + 1) * n_samples], **kwargs\n                )\n                all_out.append(out)\n        out = torch.cat(all_out, dim=0)\n        return out\n\n    @torch.no_grad()\n    def encode_first_stage(self, x):\n        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])\n        n_rounds = math.ceil(x.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                out = self.first_stage_model.encode(\n                    x[n * n_samples : (n + 1) * n_samples]\n                )\n                all_out.append(out)\n        z = torch.cat(all_out, dim=0)\n        z = self.scale_factor * z\n        return z\n\n    def forward(self, x, batch):\n        loss = self.loss_fn(self.model, self.denoiser, 
self.conditioner, x, batch)\n        loss_mean = loss.mean()\n        loss_dict = {\"loss\": loss_mean}\n        return loss_mean, loss_dict\n\n    def shared_step(self, batch: Dict) -> Any:\n        x = self.get_input(batch)\n        x = self.encode_first_stage(x)\n        batch[\"global_step\"] = self.global_step\n        loss, loss_dict = self(x, batch)\n        return loss, loss_dict\n\n    def training_step(self, batch, batch_idx):\n        loss, loss_dict = self.shared_step(batch)\n\n        self.log_dict(\n            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False\n        )\n\n        self.log(\n            \"global_step\",\n            self.global_step,\n            prog_bar=True,\n            logger=True,\n            on_step=True,\n            on_epoch=False,\n        )\n\n        if self.scheduler_config is not None:\n            lr = self.optimizers().param_groups[0][\"lr\"]\n            self.log(\n                \"lr_abs\", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False\n            )\n\n        return loss\n\n    def on_train_start(self, *args, **kwargs):\n        if self.sampler is None or self.loss_fn is None:\n            raise ValueError(\"Sampler and loss function need to be set for training.\")\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self.model)\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def 
instantiate_optimizer_from_config(self, params, lr, cfg):\n        return get_obj_from_str(cfg[\"target\"])(\n            params, lr=lr, **cfg.get(\"params\", dict())\n        )\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        for embedder in self.conditioner.embedders:\n            if embedder.is_trainable:\n                params = params + list(embedder.parameters())\n        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)\n        if self.scheduler_config is not None:\n            scheduler = instantiate_from_config(self.scheduler_config)\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    \"scheduler\": LambdaLR(opt, lr_lambda=scheduler.schedule),\n                    \"interval\": \"step\",\n                    \"frequency\": 1,\n                }\n            ]\n            return [opt], scheduler\n        return opt\n\n    @torch.no_grad()\n    def sample(\n        self,\n        cond: Dict,\n        uc: Union[Dict, None] = None,\n        batch_size: int = 16,\n        shape: Union[None, Tuple, List] = None,\n        **kwargs,\n    ):\n        randn = torch.randn(batch_size, *shape).to(self.device)\n\n        denoiser = lambda input, sigma, c: self.denoiser(\n            self.model, input, sigma, c, **kwargs\n        )\n        samples = self.sampler(denoiser, randn, cond, uc=uc)\n        return samples\n\n    @torch.no_grad()\n    def log_conditionings(self, batch: Dict, n: int) -> Dict:\n        \"\"\"\n        Defines heuristics to log different conditionings.\n        These can be lists of strings (text-to-image), tensors, ints, ...\n        \"\"\"\n        image_h, image_w = batch[self.input_key].shape[2:]\n        log = dict()\n\n        for embedder in self.conditioner.embedders:\n            if (\n                (self.log_keys is None) or (embedder.input_key in 
self.log_keys)\n            ) and not self.no_cond_log:\n                x = batch[embedder.input_key][:n]\n                if isinstance(x, torch.Tensor):\n                    if x.dim() == 1:\n                        # class-conditional, convert integer to string\n                        x = [str(x[i].item()) for i in range(x.shape[0])]\n                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)\n                    elif x.dim() == 2:\n                        # size and crop cond and the like\n                        x = [\n                            \"x\".join([str(xx) for xx in x[i].tolist()])\n                            for i in range(x.shape[0])\n                        ]\n                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)\n                    else:\n                        raise NotImplementedError()\n                elif isinstance(x, (List, ListConfig)):\n                    if isinstance(x[0], str):\n                        # strings\n                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)\n                    else:\n                        raise NotImplementedError()\n                else:\n                    raise NotImplementedError()\n                log[embedder.input_key] = xc\n        return log\n\n    @torch.no_grad()\n    def log_images(\n        self,\n        batch: Dict,\n        N: int = 8,\n        sample: bool = True,\n        ucg_keys: Optional[List[str]] = None,\n        **kwargs,\n    ) -> Dict:\n        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]\n        if ucg_keys:\n            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (\n                \"Each defined ucg key for sampling must be in the provided conditioner input keys, \"\n                f\"but we have {ucg_keys} vs. {conditioner_input_keys}\"\n            )\n        else:\n            ucg_keys = conditioner_input_keys\n        log = dict()\n\n        x = self.get_input(batch)\n\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            force_uc_zero_embeddings=ucg_keys\n            if len(self.conditioner.embedders) > 0\n            else [],\n        )\n\n        sampling_kwargs = {}\n\n        N = min(x.shape[0], N)\n        x = x.to(self.device)[:N]\n        log[\"inputs\"] = x\n        z = self.encode_first_stage(x)\n        log[\"reconstructions\"] = self.decode_first_stage(z)\n        log.update(self.log_conditionings(batch, N))\n\n        for k in c:\n            if isinstance(c[k], torch.Tensor):\n                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))\n\n        if sample:\n            with self.ema_scope(\"Plotting\"):\n                samples = self.sample(\n                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs\n                )\n            samples = self.decode_first_stage(samples)\n            log[\"samples\"] = samples\n        return log\n"
  },
  {
    "path": "sgm/modules/__init__.py",
    "content": "from .encoders.modules import GeneralConditioner\n\nUNCONDITIONAL_CONFIG = {\n    \"target\": \"sgm.modules.GeneralConditioner\",\n    \"params\": {\"emb_models\": []},\n}\n"
  },
  {
    "path": "sgm/modules/attention.py",
    "content": "import logging\nimport math\nfrom inspect import isfunction\nfrom typing import Any, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom packaging import version\nfrom torch import nn\nfrom torch.utils.checkpoint import checkpoint\n\nlogpy = logging.getLogger(__name__)\n\nif version.parse(torch.__version__) >= version.parse(\"2.0.0\"):\n    SDP_IS_AVAILABLE = True\n    from torch.backends.cuda import SDPBackend, sdp_kernel\n\n    BACKEND_MAP = {\n        SDPBackend.MATH: {\n            \"enable_math\": True,\n            \"enable_flash\": False,\n            \"enable_mem_efficient\": False,\n        },\n        SDPBackend.FLASH_ATTENTION: {\n            \"enable_math\": False,\n            \"enable_flash\": True,\n            \"enable_mem_efficient\": False,\n        },\n        SDPBackend.EFFICIENT_ATTENTION: {\n            \"enable_math\": False,\n            \"enable_flash\": False,\n            \"enable_mem_efficient\": True,\n        },\n        None: {\"enable_math\": True, \"enable_flash\": True, \"enable_mem_efficient\": True},\n    }\nelse:\n    from contextlib import nullcontext\n\n    SDP_IS_AVAILABLE = False\n    sdp_kernel = nullcontext\n    BACKEND_MAP = {}\n    logpy.warn(\n        f\"No SDP backend available, likely because you are running in pytorch \"\n        f\"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. \"\n        f\"You might want to consider upgrading.\"\n    )\n\ntry:\n    import xformers\n    import xformers.ops\n\n    XFORMERS_IS_AVAILABLE = True\nexcept:\n    XFORMERS_IS_AVAILABLE = False\n    logpy.warn(\"no module 'xformers'. 
Processing without...\")\n\n# from .diffusionmodules.util import mixed_checkpoint as checkpoint\n\n\ndef exists(val):\n    return val is not None\n\n\ndef uniq(arr):\n    return {el: True for el in arr}.keys()\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef max_neg_value(t):\n    return -torch.finfo(t.dtype).max\n\n\ndef init_(tensor):\n    dim = tensor.shape[-1]\n    std = 1 / math.sqrt(dim)\n    tensor.uniform_(-std, std)\n    return tensor\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = (\n            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())\n            if not glu\n            else GEGLU(dim, inner_dim)\n        )\n\n        self.net = nn.Sequential(\n            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(\n        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True\n    )\n\n\nclass LinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 
1)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(\n            qkv, \"b (qkv heads c) h w -> qkv b heads c (h w)\", heads=self.heads, qkv=3\n        )\n        k = k.softmax(dim=-1)\n        context = torch.einsum(\"bhdn,bhen->bhde\", k, v)\n        out = torch.einsum(\"bhde,bhdn->bhen\", context, q)\n        out = rearrange(\n            out, \"b heads c (h w) -> b (heads c) h w\", heads=self.heads, h=h, w=w\n        )\n        return self.to_out(out)\n\n\nclass SelfAttention(nn.Module):\n    ATTENTION_MODES = (\"xformers\", \"torch\", \"math\")\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        qk_scale: Optional[float] = None,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n        attn_mode: str = \"xformers\",\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        assert attn_mode in self.ATTENTION_MODES\n        self.attn_mode = attn_mode\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, L, C = x.shape\n\n        qkv = self.qkv(x)\n        if self.attn_mode == \"torch\":\n            qkv = rearrange(\n                qkv, \"B L (K H D) -> K B H L D\", K=3, H=self.num_heads\n            ).float()\n            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D\n            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)\n            x = rearrange(x, \"B H L D -> B L (H D)\")\n        elif self.attn_mode == \"xformers\":\n            qkv = rearrange(qkv, \"B L (K H D) -> K B L H D\", K=3, H=self.num_heads)\n            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D\n     
       x = xformers.ops.memory_efficient_attention(q, k, v)\n            x = rearrange(x, \"B L H D -> B L (H D)\", H=self.num_heads)\n        elif self.attn_mode == \"math\":\n            qkv = rearrange(qkv, \"B L (K H D) -> K B H L D\", K=3, H=self.num_heads)\n            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D\n            attn = (q @ k.transpose(-2, -1)) * self.scale\n            attn = attn.softmax(dim=-1)\n            attn = self.attn_drop(attn)\n            x = (attn @ v).transpose(1, 2).reshape(B, L, C)\n        else:\n            raise NotImplementedError(f\"Unknown attention mode: {self.attn_mode}\")\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass SpatialSelfAttention(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.k = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.v = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.proj_out = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b, c, h, w = q.shape\n        q = rearrange(q, \"b c h w -> b (h w) c\")\n        k = rearrange(k, \"b c h w -> b c (h w)\")\n        w_ = torch.einsum(\"bij,bjk->bik\", q, k)\n\n        w_ = w_ * (int(c) ** (-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = rearrange(v, \"b c h w -> b c (h w)\")\n        w_ = rearrange(w_, \"b i j -> b j i\")\n        h_ = torch.einsum(\"bij,bjk->bik\", v, w_)\n        h_ = rearrange(h_, \"b 
c (h w) -> b c h w\", h=h)\n        h_ = self.proj_out(h_)\n\n        return x + h_\n\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        query_dim,\n        context_dim=None,\n        heads=8,\n        dim_head=64,\n        dropout=0.0,\n        backend=None,\n    ):\n        super().__init__()\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.scale = dim_head**-0.5\n        self.heads = heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)\n        )\n        self.backend = backend\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        additional_tokens=None,\n        n_times_crossframe_attn_in_self=0,\n    ):\n        h = self.heads\n\n        if additional_tokens is not None:\n            # get the number of masked tokens at the beginning of the output sequence\n            n_tokens_to_mask = additional_tokens.shape[1]\n            # add additional token\n            x = torch.cat([additional_tokens, x], dim=1)\n\n        q = self.to_q(x)\n        context = default(context, x)\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        if n_times_crossframe_attn_in_self:\n            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439\n            assert x.shape[0] % n_times_crossframe_attn_in_self == 0\n            n_cp = x.shape[0] // n_times_crossframe_attn_in_self\n            k = repeat(\n                k[::n_times_crossframe_attn_in_self], \"b ... -> (b n) ...\", n=n_cp\n            )\n            v = repeat(\n                v[::n_times_crossframe_attn_in_self], \"b ... 
-> (b n) ...\", n=n_cp\n            )\n\n        q, k, v = map(lambda t: rearrange(t, \"b n (h d) -> b h n d\", h=h), (q, k, v))\n\n        ## old\n        \"\"\"\n        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale\n        del q, k\n\n        if exists(mask):\n            mask = rearrange(mask, 'b ... -> b (...)')\n            max_neg_value = -torch.finfo(sim.dtype).max\n            mask = repeat(mask, 'b j -> (b h) () j', h=h)\n            sim.masked_fill_(~mask, max_neg_value)\n\n        # attention, what we cannot get enough of\n        sim = sim.softmax(dim=-1)\n\n        out = einsum('b i j, b j d -> b i d', sim, v)\n        \"\"\"\n        ## new\n        with sdp_kernel(**BACKEND_MAP[self.backend]):\n            # print(\"dispatching into backend\", self.backend, \"q/k/v shape: \", q.shape, k.shape, v.shape)\n            out = F.scaled_dot_product_attention(\n                q, k, v, attn_mask=mask\n            )  # scale is dim_head ** -0.5 per default\n\n        del q, k, v\n        out = rearrange(out, \"b h n d -> b n (h d)\", h=h)\n\n        if additional_tokens is not None:\n            # remove additional token\n            out = out[:, n_tokens_to_mask:]\n        return self.to_out(out)\n\n\nclass MemoryEfficientCrossAttention(nn.Module):\n    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223\n    def __init__(\n        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs\n    ):\n        super().__init__()\n        logpy.debug(\n            f\"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, \"\n            f\"context_dim is {context_dim} and using {heads} heads with a \"\n            f\"dimension of {dim_head}.\"\n        )\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.heads = heads\n        self.dim_head = dim_head\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)\n        )\n        self.attention_op: Optional[Any] = None\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        additional_tokens=None,\n        n_times_crossframe_attn_in_self=0,\n    ):\n        if additional_tokens is not None:\n            # get the number of masked tokens at the beginning of the output sequence\n            n_tokens_to_mask = additional_tokens.shape[1]\n            # add additional token\n            x = torch.cat([additional_tokens, x], dim=1)\n        q = self.to_q(x)\n        context = default(context, x)\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        if n_times_crossframe_attn_in_self:\n            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439\n            assert x.shape[0] % n_times_crossframe_attn_in_self == 0\n            # n_cp = x.shape[0]//n_times_crossframe_attn_in_self\n            k = repeat(\n                k[::n_times_crossframe_attn_in_self],\n                \"b ... -> (b n) ...\",\n                n=n_times_crossframe_attn_in_self,\n            )\n            v = repeat(\n                v[::n_times_crossframe_attn_in_self],\n                \"b ... 
-> (b n) ...\",\n                n=n_times_crossframe_attn_in_self,\n            )\n\n        b, _, _ = q.shape\n        q, k, v = map(\n            lambda t: t.unsqueeze(3)\n            .reshape(b, t.shape[1], self.heads, self.dim_head)\n            .permute(0, 2, 1, 3)\n            .reshape(b * self.heads, t.shape[1], self.dim_head)\n            .contiguous(),\n            (q, k, v),\n        )\n\n        # actually compute the attention, what we cannot get enough of\n        if version.parse(xformers.__version__) >= version.parse(\"0.0.21\"):\n            # NOTE: workaround for\n            # https://github.com/facebookresearch/xformers/issues/845\n            max_bs = 32768\n            N = q.shape[0]\n            n_batches = math.ceil(N / max_bs)\n            out = list()\n            for i_batch in range(n_batches):\n                batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)\n                out.append(\n                    xformers.ops.memory_efficient_attention(\n                        q[batch],\n                        k[batch],\n                        v[batch],\n                        attn_bias=None,\n                        op=self.attention_op,\n                    )\n                )\n            out = torch.cat(out, 0)\n        else:\n            out = xformers.ops.memory_efficient_attention(\n                q, k, v, attn_bias=None, op=self.attention_op\n            )\n\n        # TODO: Use this directly in the attention operation, as a bias\n        if exists(mask):\n            raise NotImplementedError\n        out = (\n            out.unsqueeze(0)\n            .reshape(b, self.heads, out.shape[1], self.dim_head)\n            .permute(0, 2, 1, 3)\n            .reshape(b, out.shape[1], self.heads * self.dim_head)\n        )\n        if additional_tokens is not None:\n            # remove additional token\n            out = out[:, n_tokens_to_mask:]\n        return self.to_out(out)\n\n\nclass BasicTransformerBlock(nn.Module):\n    
ATTENTION_MODES = {\n        \"softmax\": CrossAttention,  # vanilla attention\n        \"softmax-xformers\": MemoryEfficientCrossAttention,  # ampere\n    }\n\n    def __init__(\n        self,\n        dim,\n        n_heads,\n        d_head,\n        dropout=0.0,\n        context_dim=None,\n        gated_ff=True,\n        checkpoint=True,\n        disable_self_attn=False,\n        attn_mode=\"softmax\",\n        sdp_backend=None,\n    ):\n        super().__init__()\n        assert attn_mode in self.ATTENTION_MODES\n        if attn_mode != \"softmax\" and not XFORMERS_IS_AVAILABLE:\n            logpy.warn(\n                f\"Attention mode '{attn_mode}' is not available. Falling \"\n                f\"back to native attention. This is not a problem in \"\n                f\"Pytorch >= 2.0. FYI, you are running with PyTorch \"\n                f\"version {torch.__version__}.\"\n            )\n            attn_mode = \"softmax\"\n        elif attn_mode == \"softmax\" and not SDP_IS_AVAILABLE:\n            logpy.warn(\n                \"We do not support vanilla attention anymore, as it is too \"\n                \"expensive. Sorry.\"\n            )\n            if not XFORMERS_IS_AVAILABLE:\n                assert (\n                    False\n                ), \"Please install xformers via e.g. 
'pip install xformers==0.0.16'\"\n            else:\n                logpy.info(\"Falling back to xformers efficient attention.\")\n                attn_mode = \"softmax-xformers\"\n        attn_cls = self.ATTENTION_MODES[attn_mode]\n        if version.parse(torch.__version__) >= version.parse(\"2.0.0\"):\n            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)\n        else:\n            assert sdp_backend is None\n        self.disable_self_attn = disable_self_attn\n        self.attn1 = attn_cls(\n            query_dim=dim,\n            heads=n_heads,\n            dim_head=d_head,\n            dropout=dropout,\n            context_dim=context_dim if self.disable_self_attn else None,\n            backend=sdp_backend,\n        )  # is a self-attention if not self.disable_self_attn\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = attn_cls(\n            query_dim=dim,\n            context_dim=context_dim,\n            heads=n_heads,\n            dim_head=d_head,\n            dropout=dropout,\n            backend=sdp_backend,\n        )  # is self-attn if context is none\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n        self.checkpoint = checkpoint\n        if self.checkpoint:\n            logpy.debug(f\"{self.__class__.__name__} is using checkpointing\")\n\n    def forward(\n        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0\n    ):\n        kwargs = {\"x\": x}\n\n        if context is not None:\n            kwargs.update({\"context\": context})\n\n        if additional_tokens is not None:\n            kwargs.update({\"additional_tokens\": additional_tokens})\n\n        if n_times_crossframe_attn_in_self:\n            kwargs.update(\n                {\"n_times_crossframe_attn_in_self\": n_times_crossframe_attn_in_self}\n            )\n\n        # return mixed_checkpoint(self._forward, kwargs, 
self.parameters(), self.checkpoint)\n        if self.checkpoint:\n            # inputs = {\"x\": x, \"context\": context}\n            return checkpoint(self._forward, x, context)\n            # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)\n        else:\n            return self._forward(**kwargs)\n\n    def _forward(\n        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0\n    ):\n        x = (\n            self.attn1(\n                self.norm1(x),\n                context=context if self.disable_self_attn else None,\n                additional_tokens=additional_tokens,\n                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self\n                if not self.disable_self_attn\n                else 0,\n            )\n            + x\n        )\n        x = (\n            self.attn2(\n                self.norm2(x), context=context, additional_tokens=additional_tokens\n            )\n            + x\n        )\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass BasicTransformerSingleLayerBlock(nn.Module):\n    ATTENTION_MODES = {\n        \"softmax\": CrossAttention,  # vanilla attention\n        \"softmax-xformers\": MemoryEfficientCrossAttention  # on the A100s not quite as fast as the above version\n        # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])\n    }\n\n    def __init__(\n        self,\n        dim,\n        n_heads,\n        d_head,\n        dropout=0.0,\n        context_dim=None,\n        gated_ff=True,\n        checkpoint=True,\n        attn_mode=\"softmax\",\n    ):\n        super().__init__()\n        assert attn_mode in self.ATTENTION_MODES\n        attn_cls = self.ATTENTION_MODES[attn_mode]\n        self.attn1 = attn_cls(\n            query_dim=dim,\n            heads=n_heads,\n            dim_head=d_head,\n            dropout=dropout,\n            context_dim=context_dim,\n        )\n     
   self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.checkpoint = checkpoint\n\n    def forward(self, x, context=None):\n        # inputs = {\"x\": x, \"context\": context}\n        # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)\n        return checkpoint(self._forward, x, context)\n\n    def _forward(self, x, context=None):\n        x = self.attn1(self.norm1(x), context=context) + x\n        x = self.ff(self.norm2(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data.\n    First, project the input (aka embedding)\n    and reshape to b, t, d.\n    Then apply standard transformer action.\n    Finally, reshape to image\n    NEW: use_linear for more efficiency instead of the 1x1 convs\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        n_heads,\n        d_head,\n        depth=1,\n        dropout=0.0,\n        context_dim=None,\n        disable_self_attn=False,\n        use_linear=False,\n        attn_type=\"softmax\",\n        use_checkpoint=True,\n        # sdp_backend=SDPBackend.FLASH_ATTENTION\n        sdp_backend=None,\n    ):\n        super().__init__()\n        logpy.debug(\n            f\"constructing {self.__class__.__name__} of depth {depth} w/ \"\n            f\"{in_channels} channels and {n_heads} heads.\"\n        )\n\n        if exists(context_dim) and not isinstance(context_dim, list):\n            context_dim = [context_dim]\n        if exists(context_dim) and isinstance(context_dim, list):\n            if depth != len(context_dim):\n                logpy.warn(\n                    f\"{self.__class__.__name__}: Found context dims \"\n                    f\"{context_dim} of depth {len(context_dim)}, which does not \"\n                    f\"match the specified 'depth' of {depth}. 
Setting context_dim \"\n                    f\"to {depth * [context_dim[0]]} now.\"\n                )\n                # depth does not match context dims.\n                assert all(\n                    map(lambda x: x == context_dim[0], context_dim)\n                ), \"need homogeneous context_dim to match depth automatically\"\n                context_dim = depth * [context_dim[0]]\n        elif context_dim is None:\n            context_dim = [None] * depth\n        self.in_channels = in_channels\n        inner_dim = n_heads * d_head\n        self.norm = Normalize(in_channels)\n        if not use_linear:\n            self.proj_in = nn.Conv2d(\n                in_channels, inner_dim, kernel_size=1, stride=1, padding=0\n            )\n        else:\n            self.proj_in = nn.Linear(in_channels, inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    inner_dim,\n                    n_heads,\n                    d_head,\n                    dropout=dropout,\n                    context_dim=context_dim[d],\n                    disable_self_attn=disable_self_attn,\n                    attn_mode=attn_type,\n                    checkpoint=use_checkpoint,\n                    sdp_backend=sdp_backend,\n                )\n                for d in range(depth)\n            ]\n        )\n        if not use_linear:\n            self.proj_out = zero_module(\n                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)\n            )\n        else:\n            # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))\n            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))\n        self.use_linear = use_linear\n\n    def forward(self, x, context=None):\n        # note: if no context is given, cross-attention defaults to self-attention\n        if not isinstance(context, list):\n            context = [context]\n        b, c, h, w = 
x.shape\n        x_in = x\n        x = self.norm(x)\n        if not self.use_linear:\n            x = self.proj_in(x)\n        x = rearrange(x, \"b c h w -> b (h w) c\").contiguous()\n        if self.use_linear:\n            x = self.proj_in(x)\n        for i, block in enumerate(self.transformer_blocks):\n            if i > 0 and len(context) == 1:\n                i = 0  # use same context for each block\n            x = block(x, context=context[i])\n        if self.use_linear:\n            x = self.proj_out(x)\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w).contiguous()\n        if not self.use_linear:\n            x = self.proj_out(x)\n        return x + x_in\n\n\nclass SimpleTransformer(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        depth: int,\n        heads: int,\n        dim_head: int,\n        context_dim: Optional[int] = None,\n        dropout: float = 0.0,\n        checkpoint: bool = True,\n    ):\n        super().__init__()\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                BasicTransformerBlock(\n                    dim,\n                    heads,\n                    dim_head,\n                    dropout=dropout,\n                    context_dim=context_dim,\n                    attn_mode=\"softmax-xformers\",\n                    checkpoint=checkpoint,\n                )\n            )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        for layer in self.layers:\n            x = layer(x, context)\n        return x\n"
  },
  {
    "path": "sgm/modules/autoencoding/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/autoencoding/losses/__init__.py",
    "content": "__all__ = [\n    \"GeneralLPIPSWithDiscriminator\",\n    \"LatentLPIPS\",\n]\n\nfrom .discriminator_loss import GeneralLPIPSWithDiscriminator\nfrom .lpips import LatentLPIPS\n"
  },
  {
    "path": "sgm/modules/autoencoding/losses/discriminator_loss.py",
    "content": "from typing import Dict, Iterator, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom einops import rearrange\nfrom matplotlib import colormaps\nfrom matplotlib import pyplot as plt\n\nfrom ....util import default, instantiate_from_config\nfrom ..lpips.loss.lpips import LPIPS\nfrom ..lpips.model.model import weights_init\nfrom ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss\n\n\nclass GeneralLPIPSWithDiscriminator(nn.Module):\n    def __init__(\n        self,\n        disc_start: int,\n        logvar_init: float = 0.0,\n        disc_num_layers: int = 3,\n        disc_in_channels: int = 3,\n        disc_factor: float = 1.0,\n        disc_weight: float = 1.0,\n        perceptual_weight: float = 1.0,\n        disc_loss: str = \"hinge\",\n        scale_input_to_tgt_size: bool = False,\n        dims: int = 2,\n        learn_logvar: bool = False,\n        regularization_weights: Union[None, Dict[str, float]] = None,\n        additional_log_keys: Optional[List[str]] = None,\n        discriminator_config: Optional[Dict] = None,\n    ):\n        super().__init__()\n        self.dims = dims\n        if self.dims > 2:\n            print(\n                f\"running with dims={dims}. 
This means that for perceptual loss \"\n                f\"calculation, the LPIPS loss will be applied to each frame \"\n                f\"independently.\"\n            )\n        self.scale_input_to_tgt_size = scale_input_to_tgt_size\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        # output log variance\n        self.logvar = nn.Parameter(\n            torch.full((), logvar_init), requires_grad=learn_logvar\n        )\n        self.learn_logvar = learn_logvar\n\n        discriminator_config = default(\n            discriminator_config,\n            {\n                \"target\": \"sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator\",\n                \"params\": {\n                    \"input_nc\": disc_in_channels,\n                    \"n_layers\": disc_num_layers,\n                    \"use_actnorm\": False,\n                },\n            },\n        )\n\n        self.discriminator = instantiate_from_config(discriminator_config).apply(\n            weights_init\n        )\n        self.discriminator_iter_start = disc_start\n        self.disc_loss = hinge_d_loss if disc_loss == \"hinge\" else vanilla_d_loss\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.regularization_weights = default(regularization_weights, {})\n\n        self.forward_keys = [\n            \"optimizer_idx\",\n            \"global_step\",\n            \"last_layer\",\n            \"split\",\n            \"regularization_log\",\n        ]\n\n        self.additional_log_keys = set(default(additional_log_keys, []))\n        self.additional_log_keys.update(set(self.regularization_weights.keys()))\n\n    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:\n        return self.discriminator.parameters()\n\n    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:\n        if 
self.learn_logvar:\n            yield self.logvar\n        yield from ()\n\n    @torch.no_grad()\n    def log_images(\n        self, inputs: torch.Tensor, reconstructions: torch.Tensor\n    ) -> Dict[str, torch.Tensor]:\n        # calc logits of real/fake\n        logits_real = self.discriminator(inputs.contiguous().detach())\n        if len(logits_real.shape) < 4:\n            # Non patch-discriminator\n            return dict()\n        logits_fake = self.discriminator(reconstructions.contiguous().detach())\n        # -> (b, 1, h, w)\n\n        # parameters for colormapping\n        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()\n        cmap = colormaps[\"PiYG\"]  # diverging colormap\n\n        def to_colormap(logits: torch.Tensor) -> torch.Tensor:\n            \"\"\"(b, 1, ...) -> (b, 3, ...)\"\"\"\n            logits = (logits + high) / (2 * high)\n            logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel\n            # -> (b, 1, ..., 3)\n            logits = torch.from_numpy(logits_np).to(logits.device)\n            return rearrange(logits, \"b 1 ... 
c -> b c ...\")\n\n        logits_real = torch.nn.functional.interpolate(\n            logits_real,\n            size=inputs.shape[-2:],\n            mode=\"nearest\",\n            antialias=False,\n        )\n        logits_fake = torch.nn.functional.interpolate(\n            logits_fake,\n            size=reconstructions.shape[-2:],\n            mode=\"nearest\",\n            antialias=False,\n        )\n\n        # alpha value of logits for overlay\n        alpha_real = torch.abs(logits_real) / high\n        alpha_fake = torch.abs(logits_fake) / high\n        # -> (b, 1, h, w) in range [0, 0.5]\n        # alpha value of lines don't really matter, since the values are the same\n        # for both images and logits anyway\n        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)\n        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)\n        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)\n        # -> (1, h, w)\n        # blend logits and images together\n\n        # prepare logits for plotting\n        logits_real = to_colormap(logits_real)\n        logits_fake = to_colormap(logits_fake)\n        # resize logits\n        # -> (b, 3, h, w)\n\n        # make some grids\n        # add all logits to one plot\n        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)\n        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)\n        # I just love how torchvision calls the number of columns `nrow`\n        grid_logits = torch.cat((logits_real, logits_fake), dim=1)\n        # -> (3, h, w)\n\n        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)\n        grid_images_fake = torchvision.utils.make_grid(\n            0.5 * reconstructions + 0.5, nrow=4\n        )\n        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)\n        # -> (3, h, w) in range [0, 1]\n\n        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * 
grid_images\n\n        # Create labeled colorbar\n        dpi = 100\n        height = 128 / dpi\n        width = grid_logits.shape[2] / dpi\n        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)\n        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)\n        plt.colorbar(\n            img,\n            cax=ax,\n            orientation=\"horizontal\",\n            fraction=0.9,\n            aspect=width / height,\n            pad=0.0,\n        )\n        img.set_visible(False)\n        fig.tight_layout()\n        fig.canvas.draw()\n        # manually convert figure to numpy; `tostring_rgb` is deprecated in recent\n        # matplotlib, so read the RGBA buffer and drop the alpha channel instead\n        cbar_np = np.asarray(fig.canvas.buffer_rgba())[..., :3]\n        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0\n        plt.close(fig)  # release the figure so repeated logging does not leak memory\n        cbar = rearrange(cbar, \"h w c -> c h w\").to(grid_logits.device)\n\n        # Add colorbar to plot\n        annotated_grid = torch.cat((grid_logits, cbar), dim=1)\n        blended_grid = torch.cat((grid_blend, cbar), dim=1)\n        return {\n            \"vis_logits\": 2 * annotated_grid[None, ...] - 1,\n            \"vis_logits_blended\": 2 * blended_grid[None, ...] 
- 1,\n        }\n\n    def calculate_adaptive_weight(\n        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor\n    ) -> torch.Tensor:\n        # weight the GAN loss by the ratio of the NLL and GAN gradient norms at\n        # `last_layer`, so that neither term dominates the generator update\n        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(\n        self,\n        inputs: torch.Tensor,\n        reconstructions: torch.Tensor,\n        *,  # keyword-only: these arguments were reordered, so they must be named\n        regularization_log: Dict[str, torch.Tensor],\n        optimizer_idx: int,\n        global_step: int,\n        last_layer: torch.Tensor,\n        split: str = \"train\",\n        weights: Union[None, float, torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, dict]:\n        if self.scale_input_to_tgt_size:\n            inputs = torch.nn.functional.interpolate(\n                inputs, reconstructions.shape[2:], mode=\"bicubic\", antialias=True\n            )\n\n        if self.dims > 2:\n            inputs, reconstructions = map(\n                lambda x: rearrange(x, \"b c t h w -> (b t) c h w\"),\n                (inputs, reconstructions),\n            )\n\n        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(\n                inputs.contiguous(), reconstructions.contiguous()\n            )\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n\n        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if global_step >= self.discriminator_iter_start or not self.training:\n                logits_fake = 
self.discriminator(reconstructions.contiguous())\n                g_loss = -torch.mean(logits_fake)\n                if self.training:\n                    d_weight = self.calculate_adaptive_weight(\n                        nll_loss, g_loss, last_layer=last_layer\n                    )\n                else:\n                    d_weight = torch.tensor(1.0)\n            else:\n                d_weight = torch.tensor(0.0)\n                g_loss = torch.tensor(0.0, requires_grad=True)\n\n            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss\n            log = dict()\n            for k in regularization_log:\n                if k in self.regularization_weights:\n                    loss = loss + self.regularization_weights[k] * regularization_log[k]\n                if k in self.additional_log_keys:\n                    log[f\"{split}/{k}\"] = regularization_log[k].detach().float().mean()\n\n            log.update(\n                {\n                    f\"{split}/loss/total\": loss.clone().detach().mean(),\n                    f\"{split}/loss/nll\": nll_loss.detach().mean(),\n                    f\"{split}/loss/rec\": rec_loss.detach().mean(),\n                    f\"{split}/loss/g\": g_loss.detach().mean(),\n                    f\"{split}/scalars/logvar\": self.logvar.detach(),\n                    f\"{split}/scalars/d_weight\": d_weight.detach(),\n                }\n            )\n\n            return loss, log\n        elif optimizer_idx == 1:\n            # second pass for discriminator update\n            logits_real = self.discriminator(inputs.contiguous().detach())\n            logits_fake = self.discriminator(reconstructions.contiguous().detach())\n\n            if global_step >= self.discriminator_iter_start or not self.training:\n                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)\n            else:\n                d_loss = torch.tensor(0.0, requires_grad=True)\n\n            log = {\n                
f\"{split}/loss/disc\": d_loss.clone().detach().mean(),\n                f\"{split}/logits/real\": logits_real.detach().mean(),\n                f\"{split}/logits/fake\": logits_fake.detach().mean(),\n            }\n            return d_loss, log\n        else:\n            raise NotImplementedError(f\"Unknown optimizer_idx {optimizer_idx}\")\n\n    def get_nll_loss(\n        self,\n        rec_loss: torch.Tensor,\n        weights: Optional[Union[float, torch.Tensor]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar\n        weighted_nll_loss = nll_loss\n        if weights is not None:\n            weighted_nll_loss = weights * nll_loss\n        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]\n        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n\n        return nll_loss, weighted_nll_loss\n"
  },
  {
    "path": "sgm/modules/autoencoding/losses/lpips.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom ....util import instantiate_from_config\nfrom ..lpips.loss.lpips import LPIPS\n\n\nclass LatentLPIPS(nn.Module):\n    def __init__(\n        self,\n        decoder_config,\n        perceptual_weight=1.0,\n        latent_weight=1.0,\n        scale_input_to_tgt_size=False,\n        scale_tgt_to_input_size=False,\n        perceptual_weight_on_inputs=0.0,\n    ):\n        super().__init__()\n        self.scale_input_to_tgt_size = scale_input_to_tgt_size\n        self.scale_tgt_to_input_size = scale_tgt_to_input_size\n        self.init_decoder(decoder_config)\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        self.latent_weight = latent_weight\n        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs\n\n    def init_decoder(self, config):\n        self.decoder = instantiate_from_config(config)\n        if hasattr(self.decoder, \"encoder\"):\n            del self.decoder.encoder\n\n    def forward(self, latent_inputs, latent_predictions, image_inputs, split=\"train\"):\n        log = dict()\n        latent_loss = ((latent_inputs - latent_predictions) ** 2).mean()\n        log[f\"{split}/latent_l2_loss\"] = latent_loss.detach()\n        # reduce to a scalar and apply `latent_weight` up front, so the returned loss\n        # is a scalar even when the perceptual terms below are disabled\n        loss = self.latent_weight * latent_loss\n        image_reconstructions = None\n        if self.perceptual_weight > 0.0:\n            image_reconstructions = self.decoder.decode(latent_predictions)\n            image_targets = self.decoder.decode(latent_inputs)\n            perceptual_loss = self.perceptual_loss(\n                image_targets.contiguous(), image_reconstructions.contiguous()\n            )\n            loss = loss + self.perceptual_weight * perceptual_loss.mean()\n            log[f\"{split}/perceptual_loss\"] = perceptual_loss.mean().detach()\n\n        if self.perceptual_weight_on_inputs > 0.0:\n            # reuse the reconstructions from the perceptual branch above, if any,\n            # instead of always running the decoder a second time\n            if image_reconstructions is None:\n                image_reconstructions = self.decoder.decode(latent_predictions)\n            if self.scale_input_to_tgt_size:\n                image_inputs = torch.nn.functional.interpolate(\n                    image_inputs,\n                    image_reconstructions.shape[2:],\n                    mode=\"bicubic\",\n                    antialias=True,\n                )\n            elif self.scale_tgt_to_input_size:\n                image_reconstructions = torch.nn.functional.interpolate(\n                    image_reconstructions,\n                    image_inputs.shape[2:],\n                    mode=\"bicubic\",\n                    antialias=True,\n                )\n\n            perceptual_loss2 = self.perceptual_loss(\n                image_inputs.contiguous(), image_reconstructions.contiguous()\n            )\n            loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()\n            log[f\"{split}/perceptual_loss_on_inputs\"] = perceptual_loss2.mean().detach()\n        return loss, log\n"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/.gitignore",
    "content": "vgg.pth"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/LICENSE",
    "content": "Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/autoencoding/lpips/loss/lpips.py",
    "content": "\"\"\"Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models\"\"\"\n\nfrom collections import namedtuple\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import models\n\nfrom ..util import get_ckpt_path\n\n\nclass LPIPS(nn.Module):\n    # Learned perceptual metric\n    def __init__(self, use_dropout=True):\n        super().__init__()\n        self.scaling_layer = ScalingLayer()\n        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels\n        self.net = vgg16(pretrained=True, requires_grad=False)\n        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n        self.load_from_pretrained()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def load_from_pretrained(self, name=\"vgg_lpips\"):\n        ckpt = get_ckpt_path(name, \"sgm/modules/autoencoding/lpips/loss\")\n        self.load_state_dict(\n            torch.load(ckpt, map_location=torch.device(\"cpu\")), strict=False\n        )\n        print(\"loaded pretrained LPIPS loss from {}\".format(ckpt))\n\n    @classmethod\n    def from_pretrained(cls, name=\"vgg_lpips\"):\n        if name != \"vgg_lpips\":\n            raise NotImplementedError\n        model = cls()\n        # `get_ckpt_path` needs a root dir; reuse the one from `load_from_pretrained`\n        ckpt = get_ckpt_path(name, \"sgm/modules/autoencoding/lpips/loss\")\n        model.load_state_dict(\n            torch.load(ckpt, map_location=torch.device(\"cpu\")), strict=False\n        )\n        return model\n\n    def forward(self, input, target):\n        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))\n        outs0, outs1 = self.net(in0_input), self.net(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n        lins = [self.lin0, self.lin1, 
self.lin2, self.lin3, self.lin4]\n        for kk in range(len(self.chns)):\n            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(\n                outs1[kk]\n            )\n            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2\n\n        res = [\n            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)\n            for kk in range(len(self.chns))\n        ]\n        val = res[0]\n        for l in range(1, len(self.chns)):\n            val += res[l]\n        return val\n\n\nclass ScalingLayer(nn.Module):\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer(\n            \"shift\", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]\n        )\n        self.register_buffer(\n            \"scale\", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]\n        )\n\n    def forward(self, inp):\n        return (inp - self.shift) / self.scale\n\n\nclass NetLinLayer(nn.Module):\n    \"\"\"A single linear layer which does a 1x1 conv\"\"\"\n\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n        layers = (\n            [\n                nn.Dropout(),\n            ]\n            if (use_dropout)\n            else []\n        )\n        layers += [\n            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),\n        ]\n        self.model = nn.Sequential(*layers)\n\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), 
vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\n            \"VggOutputs\", [\"relu1_2\", \"relu2_2\", \"relu3_3\", \"relu4_3\", \"relu5_3\"]\n        )\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n        return out\n\n\ndef normalize_tensor(x, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))\n    return x / (norm_factor + eps)\n\n\ndef spatial_average(x, keepdim=True):\n    return x.mean([2, 3], keepdim=keepdim)\n"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/model/LICENSE",
    "content": "Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n--------------------------- LICENSE FOR pix2pix --------------------------------\nBSD License\n\nFor pix2pix software\nCopyright (c) 2016, Phillip Isola and Jun-Yan Zhu\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other 
materials provided with the distribution.\n\n----------------------------- LICENSE FOR DCGAN --------------------------------\nBSD License\n\nFor dcgan.torch software\n\nCopyright (c) 2015, Facebook, Inc. All rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n\nRedistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n\nRedistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n\nNeither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "sgm/modules/autoencoding/lpips/model/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/autoencoding/lpips/model/model.py",
    "content": "import functools\n\nimport torch.nn as nn\n\nfrom ..util import ActNorm\n\n\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        nn.init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find(\"BatchNorm\") != -1:\n        nn.init.normal_(m.weight.data, 1.0, 0.02)\n        nn.init.constant_(m.bias.data, 0)\n\n\nclass NLayerDiscriminator(nn.Module):\n    \"\"\"Defines a PatchGAN discriminator as in Pix2Pix\n    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py\n    \"\"\"\n\n    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):\n        \"\"\"Construct a PatchGAN discriminator\n        Parameters:\n            input_nc (int)     -- the number of channels in input images\n            ndf (int)          -- the number of filters in the first conv layer\n            n_layers (int)     -- the number of stride-2 conv layers\n            use_actnorm (bool) -- use ActNorm instead of BatchNorm2d\n        \"\"\"\n        super(NLayerDiscriminator, self).__init__()\n        if not use_actnorm:\n            norm_layer = nn.BatchNorm2d\n        else:\n            norm_layer = ActNorm\n        if (\n            type(norm_layer) == functools.partial\n        ):  # no need to use bias as BatchNorm2d has affine parameters\n            use_bias = norm_layer.func != nn.BatchNorm2d\n        else:\n            use_bias = norm_layer != nn.BatchNorm2d\n\n        kw = 4\n        padw = 1\n        sequence = [\n            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n            nn.LeakyReLU(0.2, True),\n        ]\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):  # gradually increase the number of filters\n            nf_mult_prev = nf_mult\n            nf_mult = min(2**n, 8)\n            sequence += [\n                nn.Conv2d(\n                    ndf * nf_mult_prev,\n                    ndf * nf_mult,\n    
                kernel_size=kw,\n                    stride=2,\n                    padding=padw,\n                    bias=use_bias,\n                ),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True),\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2**n_layers, 8)\n        sequence += [\n            nn.Conv2d(\n                ndf * nf_mult_prev,\n                ndf * nf_mult,\n                kernel_size=kw,\n                stride=1,\n                padding=padw,\n                bias=use_bias,\n            ),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True),\n        ]\n\n        sequence += [\n            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)\n        ]  # output 1 channel prediction map\n        self.main = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        \"\"\"Standard forward.\"\"\"\n        return self.main(input)\n"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/util.py",
    "content": "import hashlib\nimport os\n\nimport requests\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\n\nURL_MAP = {\"vgg_lpips\": \"https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1\"}\n\nCKPT_MAP = {\"vgg_lpips\": \"vgg.pth\"}\n\nMD5_MAP = {\"vgg_lpips\": \"d507d7349b931f0638a25a48a722f98a\"}\n\n\ndef download(url, local_path, chunk_size=1024):\n    os.makedirs(os.path.split(local_path)[0], exist_ok=True)\n    with requests.get(url, stream=True) as r:\n        r.raise_for_status()  # fail early instead of saving an HTTP error page\n        total_size = int(r.headers.get(\"content-length\", 0))\n        with tqdm(total=total_size, unit=\"B\", unit_scale=True) as pbar:\n            with open(local_path, \"wb\") as f:\n                for data in r.iter_content(chunk_size=chunk_size):\n                    if data:\n                        f.write(data)\n                        pbar.update(len(data))  # the last chunk may be shorter\n\n\ndef md5_hash(path):\n    with open(path, \"rb\") as f:\n        content = f.read()\n    return hashlib.md5(content).hexdigest()\n\n\ndef get_ckpt_path(name, root, check=False):\n    assert name in URL_MAP\n    path = os.path.join(root, CKPT_MAP[name])\n    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):\n        print(\"Downloading {} model from {} to {}\".format(name, URL_MAP[name], path))\n        download(URL_MAP[name], path)\n        md5 = md5_hash(path)\n        assert md5 == MD5_MAP[name], md5\n    return path\n\n\nclass ActNorm(nn.Module):\n    def __init__(\n        self, num_features, logdet=False, affine=True, allow_reverse_init=False\n    ):\n        assert affine\n        super().__init__()\n        self.logdet = logdet\n        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))\n        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))\n        self.allow_reverse_init = allow_reverse_init\n\n        self.register_buffer(\"initialized\", torch.tensor(0, dtype=torch.uint8))\n\n    def initialize(self, input):\n        with torch.no_grad():\n            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)\n            mean = (\n                flatten.mean(1)\n                .unsqueeze(1)\n                .unsqueeze(2)\n                .unsqueeze(3)\n                .permute(1, 0, 2, 3)\n            )\n            std = (\n                flatten.std(1)\n                .unsqueeze(1)\n                .unsqueeze(2)\n                .unsqueeze(3)\n                .permute(1, 0, 2, 3)\n            )\n\n            self.loc.data.copy_(-mean)\n            self.scale.data.copy_(1 / (std + 1e-6))\n\n    def forward(self, input, reverse=False):\n        if reverse:\n            return self.reverse(input)\n        if len(input.shape) == 2:\n            input = input[:, :, None, None]\n            squeeze = True\n        else:\n            squeeze = False\n\n        _, _, height, width = input.shape\n\n        if self.training and self.initialized.item() == 0:\n            self.initialize(input)\n            self.initialized.fill_(1)\n\n        h = self.scale * (input + self.loc)\n\n        if squeeze:\n            h = h.squeeze(-1).squeeze(-1)\n\n        if self.logdet:\n            log_abs = torch.log(torch.abs(self.scale))\n            logdet = height * width * torch.sum(log_abs)\n            logdet = logdet * torch.ones(input.shape[0]).to(input)\n            return h, logdet\n\n        return h\n\n    def reverse(self, output):\n        if self.training and self.initialized.item() == 0:\n            if not self.allow_reverse_init:\n                raise RuntimeError(\n                    \"Initializing ActNorm in reverse direction is \"\n                    \"disabled by default. Use allow_reverse_init=True to enable.\"\n                )\n            else:\n                self.initialize(output)\n                self.initialized.fill_(1)\n\n        if len(output.shape) == 2:\n            output = output[:, :, None, None]\n            squeeze = True\n        else:\n            squeeze = False\n\n        h = output / self.scale - self.loc\n\n        if squeeze:\n            h = h.squeeze(-1).squeeze(-1)\n        return h\n"
  },
  {
    "path": "sgm/modules/autoencoding/lpips/vqperceptual.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\ndef hinge_d_loss(logits_real, logits_fake):\n    loss_real = torch.mean(F.relu(1.0 - logits_real))\n    loss_fake = torch.mean(F.relu(1.0 + logits_fake))\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\n\ndef vanilla_d_loss(logits_real, logits_fake):\n    d_loss = 0.5 * (\n        torch.mean(torch.nn.functional.softplus(-logits_real))\n        + torch.mean(torch.nn.functional.softplus(logits_fake))\n    )\n    return d_loss\n"
  },
  {
    "path": "sgm/modules/autoencoding/regularizers/__init__.py",
    "content": "from abc import abstractmethod\nfrom typing import Any, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ....modules.distributions.distributions import \\\n    DiagonalGaussianDistribution\nfrom .base import AbstractRegularizer\n\n\nclass DiagonalGaussianRegularizer(AbstractRegularizer):\n    def __init__(self, sample: bool = True):\n        super().__init__()\n        self.sample = sample\n\n    def get_trainable_parameters(self) -> Any:\n        yield from ()\n\n    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:\n        log = dict()\n        posterior = DiagonalGaussianDistribution(z)\n        if self.sample:\n            z = posterior.sample()\n        else:\n            z = posterior.mode()\n        kl_loss = posterior.kl()\n        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n        log[\"kl_loss\"] = kl_loss\n        return z, log\n"
  },
  {
    "path": "sgm/modules/autoencoding/regularizers/base.py",
    "content": "from abc import abstractmethod\nfrom typing import Any, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass AbstractRegularizer(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def get_trainable_parameters(self) -> Any:\n        raise NotImplementedError()\n\n\nclass IdentityRegularizer(AbstractRegularizer):\n    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:\n        return z, dict()\n\n    def get_trainable_parameters(self) -> Any:\n        yield from ()\n\n\ndef measure_perplexity(\n    predicted_indices: torch.Tensor, num_centroids: int\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py\n    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally\n    encodings = (\n        F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)\n    )\n    avg_probs = encodings.mean(0)\n    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()\n    cluster_use = torch.sum(avg_probs > 0)\n    return perplexity, cluster_use\n"
  },
  {
    "path": "sgm/modules/autoencoding/regularizers/quantize.py",
    "content": "import logging\nfrom abc import abstractmethod\nfrom typing import Dict, Iterator, Literal, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch import einsum\n\nfrom .base import AbstractRegularizer, measure_perplexity\n\nlogpy = logging.getLogger(__name__)\n\n\nclass AbstractQuantizer(AbstractRegularizer):\n    def __init__(self):\n        super().__init__()\n        # Define these in your init\n        # shape (N,)\n        self.used: Optional[torch.Tensor]\n        self.re_embed: int\n        self.unknown_index: Union[Literal[\"random\"], int]\n\n    def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:\n        assert self.used is not None, \"You need to define used indices for remap\"\n        ishape = inds.shape\n        assert len(ishape) > 1\n        inds = inds.reshape(ishape[0], -1)\n        used = self.used.to(inds)\n        match = (inds[:, :, None] == used[None, None, ...]).long()\n        new = match.argmax(-1)\n        unknown = match.sum(2) < 1\n        if self.unknown_index == \"random\":\n            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(\n                device=new.device\n            )\n        else:\n            new[unknown] = self.unknown_index\n        return new.reshape(ishape)\n\n    def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:\n        assert self.used is not None, \"You need to define used indices for remap\"\n        ishape = inds.shape\n        assert len(ishape) > 1\n        inds = inds.reshape(ishape[0], -1)\n        used = self.used.to(inds)\n        if self.re_embed > self.used.shape[0]:  # extra token\n            inds[inds >= self.used.shape[0]] = 0  # simply set to zero\n        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)\n        return back.reshape(ishape)\n\n    @abstractmethod\n    def get_codebook_entry(\n        self, 
indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None\n    ) -> torch.Tensor:\n        raise NotImplementedError()\n\n    def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:\n        yield from self.parameters()\n\n\nclass GumbelQuantizer(AbstractQuantizer):\n    \"\"\"\n    credit to @karpathy:\n    https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)\n    Gumbel Softmax trick quantizer\n    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016\n    https://arxiv.org/abs/1611.01144\n    \"\"\"\n\n    def __init__(\n        self,\n        num_hiddens: int,\n        embedding_dim: int,\n        n_embed: int,\n        straight_through: bool = True,\n        kl_weight: float = 5e-4,\n        temp_init: float = 1.0,\n        remap: Optional[str] = None,\n        unknown_index: str = \"random\",\n        loss_key: str = \"loss/vq\",\n    ) -> None:\n        super().__init__()\n\n        self.loss_key = loss_key\n        self.embedding_dim = embedding_dim\n        self.n_embed = n_embed\n\n        self.straight_through = straight_through\n        self.temperature = temp_init\n        self.kl_weight = kl_weight\n\n        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)\n        self.embed = nn.Embedding(n_embed, embedding_dim)\n\n        self.remap = remap\n        if self.remap is not None:\n            self.register_buffer(\"used\", torch.tensor(np.load(self.remap)))\n            self.re_embed = self.used.shape[0]\n        else:\n            self.used = None\n            self.re_embed = n_embed\n        if unknown_index == \"extra\":\n            self.unknown_index = self.re_embed\n            self.re_embed = self.re_embed + 1\n        else:\n            assert unknown_index == \"random\" or isinstance(\n                unknown_index, int\n            ), \"unknown index needs to be 'random', 'extra' or any integer\"\n            self.unknown_index = unknown_index  # \"random\" or \"extra\" or integer\n  
      if self.remap is not None:\n            logpy.info(\n                f\"Remapping {self.n_embed} indices to {self.re_embed} indices. \"\n                f\"Using {self.unknown_index} for unknown indices.\"\n            )\n\n    def forward(\n        self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False\n    ) -> Tuple[torch.Tensor, Dict]:\n        # force hard = True when we are in eval mode, as we must quantize.\n        # actually, always true seems to work\n        hard = self.straight_through if self.training else True\n        temp = self.temperature if temp is None else temp\n        out_dict = {}\n        logits = self.proj(z)\n        if self.remap is not None:\n            # continue only with used logits\n            full_zeros = torch.zeros_like(logits)\n            logits = logits[:, self.used, ...]\n\n        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)\n        if self.remap is not None:\n            # go back to all entries but unused set to zero\n            full_zeros[:, self.used, ...] 
= soft_one_hot\n            soft_one_hot = full_zeros\n        z_q = einsum(\"b n h w, n d -> b d h w\", soft_one_hot, self.embed.weight)\n\n        # + kl divergence to the prior loss\n        qy = F.softmax(logits, dim=1)\n        diff = (\n            self.kl_weight\n            * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()\n        )\n        out_dict[self.loss_key] = diff\n\n        ind = soft_one_hot.argmax(dim=1)\n        if self.remap is not None:\n            # map full-codebook indices to the used subset before returning them,\n            # so they round-trip through get_codebook_entry / unmap_to_all\n            ind = self.remap_to_used(ind)\n        out_dict[\"indices\"] = ind\n\n        if return_logits:\n            out_dict[\"logits\"] = logits\n\n        return z_q, out_dict\n\n    def get_codebook_entry(self, indices, shape):\n        # TODO: shape not yet optional\n        b, h, w, c = shape\n        assert b * h * w == indices.shape[0]\n        indices = rearrange(indices, \"(b h w) -> b h w\", b=b, h=h, w=w)\n        if self.remap is not None:\n            indices = self.unmap_to_all(indices)\n        one_hot = (\n            F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()\n        )\n        z_q = einsum(\"b n h w, n d -> b d h w\", one_hot, self.embed.weight)\n        return z_q\n\n\nclass VectorQuantizer(AbstractQuantizer):\n    \"\"\"\n    ____________________________________________\n    Discretization bottleneck part of the VQ-VAE.\n    Inputs:\n    - n_e : number of embeddings\n    - e_dim : dimension of embedding\n    - beta : commitment cost used in loss term,\n        beta * ||z_e(x)-sg[e]||^2\n    _____________________________________________\n    \"\"\"\n\n    def __init__(\n        self,\n        n_e: int,\n        e_dim: int,\n        beta: float = 0.25,\n        remap: Optional[str] = None,\n        unknown_index: str = \"random\",\n        sane_index_shape: bool = False,\n        log_perplexity: bool = False,\n        embedding_weight_norm: bool = False,\n        loss_key: str = \"loss/vq\",\n    ):\n        super().__init__()\n        
self.n_e = n_e\n        self.e_dim = e_dim\n        self.beta = beta\n        self.loss_key = loss_key\n\n        if not embedding_weight_norm:\n            self.embedding = nn.Embedding(self.n_e, self.e_dim)\n            self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)\n        else:\n            self.embedding = torch.nn.utils.weight_norm(\n                nn.Embedding(self.n_e, self.e_dim), dim=1\n            )\n\n        self.remap = remap\n        if self.remap is not None:\n            self.register_buffer(\"used\", torch.tensor(np.load(self.remap)))\n            self.re_embed = self.used.shape[0]\n        else:\n            self.used = None\n            self.re_embed = n_e\n        if unknown_index == \"extra\":\n            self.unknown_index = self.re_embed\n            self.re_embed = self.re_embed + 1\n        else:\n            assert unknown_index == \"random\" or isinstance(\n                unknown_index, int\n            ), \"unknown index needs to be 'random', 'extra' or any integer\"\n            self.unknown_index = unknown_index  # \"random\" or \"extra\" or integer\n        if self.remap is not None:\n            logpy.info(\n                f\"Remapping {self.n_e} indices to {self.re_embed} indices. 
\"\n                f\"Using {self.unknown_index} for unknown indices.\"\n            )\n\n        self.sane_index_shape = sane_index_shape\n        self.log_perplexity = log_perplexity\n\n    def forward(\n        self,\n        z: torch.Tensor,\n    ) -> Tuple[torch.Tensor, Dict]:\n        do_reshape = z.ndim == 4\n        if do_reshape:\n            #     # reshape z -> (batch, height, width, channel) and flatten\n            z = rearrange(z, \"b c h w -> b h w c\").contiguous()\n\n        else:\n            assert z.ndim < 4, \"No reshaping strategy for inputs > 4 dimensions defined\"\n            z = z.contiguous()\n\n        z_flattened = z.view(-1, self.e_dim)\n        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z\n\n        d = (\n            torch.sum(z_flattened**2, dim=1, keepdim=True)\n            + torch.sum(self.embedding.weight**2, dim=1)\n            - 2\n            * torch.einsum(\n                \"bd,dn->bn\", z_flattened, rearrange(self.embedding.weight, \"n d -> d n\")\n            )\n        )\n\n        min_encoding_indices = torch.argmin(d, dim=1)\n        z_q = self.embedding(min_encoding_indices).view(z.shape)\n        loss_dict = {}\n        if self.log_perplexity:\n            perplexity, cluster_usage = measure_perplexity(\n                min_encoding_indices.detach(), self.n_e\n            )\n            loss_dict.update({\"perplexity\": perplexity, \"cluster_usage\": cluster_usage})\n\n        # compute loss for embedding\n        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(\n            (z_q - z.detach()) ** 2\n        )\n        loss_dict[self.loss_key] = loss\n\n        # preserve gradients\n        z_q = z + (z_q - z).detach()\n\n        # reshape back to match original input shape\n        if do_reshape:\n            z_q = rearrange(z_q, \"b h w c -> b c h w\").contiguous()\n\n        if self.remap is not None:\n            min_encoding_indices = min_encoding_indices.reshape(\n     
           z.shape[0], -1\n            )  # add batch axis\n            min_encoding_indices = self.remap_to_used(min_encoding_indices)\n            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten\n\n        if self.sane_index_shape:\n            if do_reshape:\n                min_encoding_indices = min_encoding_indices.reshape(\n                    z_q.shape[0], z_q.shape[2], z_q.shape[3]\n                )\n            else:\n                min_encoding_indices = rearrange(\n                    min_encoding_indices, \"(b s) 1 -> b s\", b=z_q.shape[0]\n                )\n\n        loss_dict[\"min_encoding_indices\"] = min_encoding_indices\n\n        return z_q, loss_dict\n\n    def get_codebook_entry(\n        self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None\n    ) -> torch.Tensor:\n        # shape specifying (batch, height, width, channel)\n        if self.remap is not None:\n            assert shape is not None, \"Need to give shape for remap\"\n            indices = indices.reshape(shape[0], -1)  # add batch axis\n            indices = self.unmap_to_all(indices)\n            indices = indices.reshape(-1)  # flatten again\n\n        # get quantized latent vectors\n        z_q = self.embedding(indices)\n\n        if shape is not None:\n            z_q = z_q.view(shape)\n            # reshape back to match original input shape\n            z_q = z_q.permute(0, 3, 1, 2).contiguous()\n\n        return z_q\n\n\nclass EmbeddingEMA(nn.Module):\n    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):\n        super().__init__()\n        self.decay = decay\n        self.eps = eps\n        weight = torch.randn(num_tokens, codebook_dim)\n        self.weight = nn.Parameter(weight, requires_grad=False)\n        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)\n        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)\n        self.update = True\n\n    def forward(self, 
embed_id):\n        return F.embedding(embed_id, self.weight)\n\n    def cluster_size_ema_update(self, new_cluster_size):\n        self.cluster_size.data.mul_(self.decay).add_(\n            new_cluster_size, alpha=1 - self.decay\n        )\n\n    def embed_avg_ema_update(self, new_embed_avg):\n        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)\n\n    def weight_update(self, num_tokens):\n        n = self.cluster_size.sum()\n        smoothed_cluster_size = (\n            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n\n        )\n        # normalize embedding average with smoothed cluster size\n        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)\n        self.weight.data.copy_(embed_normalized)\n\n\nclass EMAVectorQuantizer(AbstractQuantizer):\n    def __init__(\n        self,\n        n_embed: int,\n        embedding_dim: int,\n        beta: float,\n        decay: float = 0.99,\n        eps: float = 1e-5,\n        remap: Optional[str] = None,\n        unknown_index: str = \"random\",\n        loss_key: str = \"loss/vq\",\n    ):\n        super().__init__()\n        self.codebook_dim = embedding_dim\n        self.num_tokens = n_embed\n        self.beta = beta\n        self.loss_key = loss_key\n\n        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)\n\n        self.remap = remap\n        if self.remap is not None:\n            self.register_buffer(\"used\", torch.tensor(np.load(self.remap)))\n            self.re_embed = self.used.shape[0]\n        else:\n            self.used = None\n            self.re_embed = n_embed\n        if unknown_index == \"extra\":\n            self.unknown_index = self.re_embed\n            self.re_embed = self.re_embed + 1\n        else:\n            assert unknown_index == \"random\" or isinstance(\n                unknown_index, int\n            ), \"unknown index needs to be 'random', 'extra' or any integer\"\n            
self.unknown_index = unknown_index  # \"random\" or \"extra\" or integer\n        if self.remap is not None:\n            logpy.info(\n                f\"Remapping {self.num_tokens} indices to {self.re_embed} indices. \"\n                f\"Using {self.unknown_index} for unknown indices.\"\n            )\n\n    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:\n        # reshape z -> (batch, height, width, channel) and flatten\n        # z, 'b c h w -> b h w c'\n        z = rearrange(z, \"b c h w -> b h w c\")\n        z_flattened = z.reshape(-1, self.codebook_dim)\n\n        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z\n        d = (\n            z_flattened.pow(2).sum(dim=1, keepdim=True)\n            + self.embedding.weight.pow(2).sum(dim=1)\n            - 2 * torch.einsum(\"bd,nd->bn\", z_flattened, self.embedding.weight)\n        )\n\n        encoding_indices = torch.argmin(d, dim=1)\n\n        z_q = self.embedding(encoding_indices).view(z.shape)\n        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)\n        avg_probs = torch.mean(encodings, dim=0)\n        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))\n\n        if self.training and self.embedding.update:\n            # EMA cluster size\n            encodings_sum = encodings.sum(0)\n            self.embedding.cluster_size_ema_update(encodings_sum)\n            # EMA embedding average\n            embed_sum = encodings.transpose(0, 1) @ z_flattened\n            self.embedding.embed_avg_ema_update(embed_sum)\n            # normalize embed_avg and update weight\n            self.embedding.weight_update(self.num_tokens)\n\n        # compute loss for embedding\n        loss = self.beta * F.mse_loss(z_q.detach(), z)\n\n        # preserve gradients\n        z_q = z + (z_q - z).detach()\n\n        # reshape back to match original input shape\n        # z_q, 'b h w c -> b c h w'\n        z_q = rearrange(z_q, 
\"b h w c -> b c h w\")\n\n        out_dict = {\n            self.loss_key: loss,\n            \"encodings\": encodings,\n            \"encoding_indices\": encoding_indices,\n            \"perplexity\": perplexity,\n        }\n\n        return z_q, out_dict\n\n\nclass VectorQuantizerWithInputProjection(VectorQuantizer):\n    def __init__(\n        self,\n        input_dim: int,\n        n_codes: int,\n        codebook_dim: int,\n        beta: float = 1.0,\n        output_dim: Optional[int] = None,\n        **kwargs,\n    ):\n        super().__init__(n_codes, codebook_dim, beta, **kwargs)\n        self.proj_in = nn.Linear(input_dim, codebook_dim)\n        self.output_dim = output_dim\n        if output_dim is not None:\n            self.proj_out = nn.Linear(codebook_dim, output_dim)\n        else:\n            self.proj_out = nn.Identity()\n\n    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:\n        rearr = False\n        in_shape = z.shape\n\n        if z.ndim > 3:\n            rearr = self.output_dim is not None\n            z = rearrange(z, \"b c ... -> b (...) c\")\n        z = self.proj_in(z)\n        z_q, loss_dict = super().forward(z)\n\n        z_q = self.proj_out(z_q)\n        if rearr:\n            if len(in_shape) == 4:\n                z_q = rearrange(z_q, \"b (h w) c -> b c h w \", w=in_shape[-1])\n            elif len(in_shape) == 5:\n                z_q = rearrange(\n                    z_q, \"b (t h w) c -> b c t h w \", w=in_shape[-1], h=in_shape[-2]\n                )\n            else:\n                raise NotImplementedError(\n                    f\"rearranging not available for {len(in_shape)}-dimensional input.\"\n                )\n\n        return z_q, loss_dict\n"
  },
  {
    "path": "sgm/modules/autoencoding/temporal_ae.py",
    "content": "from typing import Callable, Iterable, Union\n\nimport torch\nfrom einops import rearrange, repeat\n\nfrom sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,\n                                                AttnBlock, Decoder,\n                                                MemoryEfficientAttnBlock,\n                                                ResnetBlock)\nfrom sgm.modules.diffusionmodules.openaimodel import (ResBlock,\n                                                      timestep_embedding)\nfrom sgm.modules.video_attention import VideoTransformerBlock\nfrom sgm.util import partialclass\n\n\nclass VideoResBlock(ResnetBlock):\n    def __init__(\n        self,\n        out_channels,\n        *args,\n        dropout=0.0,\n        video_kernel_size=3,\n        alpha=0.0,\n        merge_strategy=\"learned\",\n        **kwargs,\n    ):\n        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)\n        if video_kernel_size is None:\n            video_kernel_size = [3, 1, 1]\n        self.time_stack = ResBlock(\n            channels=out_channels,\n            emb_channels=0,\n            dropout=dropout,\n            dims=3,\n            use_scale_shift_norm=False,\n            use_conv=False,\n            up=False,\n            down=False,\n            kernel_size=video_kernel_size,\n            use_checkpoint=False,\n            skip_t_emb=True,\n        )\n\n        self.merge_strategy = merge_strategy\n        if self.merge_strategy == \"fixed\":\n            self.register_buffer(\"mix_factor\", torch.Tensor([alpha]))\n        elif self.merge_strategy == \"learned\":\n            self.register_parameter(\n                \"mix_factor\", torch.nn.Parameter(torch.Tensor([alpha]))\n            )\n        else:\n            raise ValueError(f\"unknown merge strategy {self.merge_strategy}\")\n\n    def get_alpha(self, bs):\n        if self.merge_strategy == \"fixed\":\n            return self.mix_factor\n   
     elif self.merge_strategy == \"learned\":\n            return torch.sigmoid(self.mix_factor)\n        else:\n            raise NotImplementedError()\n\n    def forward(self, x, temb, skip_video=False, timesteps=None):\n        if timesteps is None:\n            timesteps = self.timesteps\n\n        b, c, h, w = x.shape\n\n        x = super().forward(x, temb)\n\n        if not skip_video:\n            x_mix = rearrange(x, \"(b t) c h w -> b c t h w\", t=timesteps)\n\n            x = rearrange(x, \"(b t) c h w -> b c t h w\", t=timesteps)\n\n            x = self.time_stack(x, temb)\n\n            alpha = self.get_alpha(bs=b // timesteps)\n            x = alpha * x + (1.0 - alpha) * x_mix\n\n            x = rearrange(x, \"b c t h w -> (b t) c h w\")\n        return x\n\n\nclass AE3DConv(torch.nn.Conv2d):\n    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):\n        super().__init__(in_channels, out_channels, *args, **kwargs)\n        if isinstance(video_kernel_size, Iterable):\n            padding = [int(k // 2) for k in video_kernel_size]\n        else:\n            padding = int(video_kernel_size // 2)\n\n        self.time_mix_conv = torch.nn.Conv3d(\n            in_channels=out_channels,\n            out_channels=out_channels,\n            kernel_size=video_kernel_size,\n            padding=padding,\n        )\n\n    def forward(self, input, timesteps, skip_video=False):\n        x = super().forward(input)\n        if skip_video:\n            return x\n        x = rearrange(x, \"(b t) c h w -> b c t h w\", t=timesteps)\n        x = self.time_mix_conv(x)\n        return rearrange(x, \"b c t h w -> (b t) c h w\")\n\n\nclass VideoBlock(AttnBlock):\n    def __init__(\n        self, in_channels: int, alpha: float = 0, merge_strategy: str = \"learned\"\n    ):\n        super().__init__(in_channels)\n        # no context, single headed, as in base class\n        self.time_mix_block = VideoTransformerBlock(\n            
dim=in_channels,\n            n_heads=1,\n            d_head=in_channels,\n            checkpoint=False,\n            ff_in=True,\n            attn_mode=\"softmax\",\n        )\n\n        time_embed_dim = self.in_channels * 4\n        self.video_time_embed = torch.nn.Sequential(\n            torch.nn.Linear(self.in_channels, time_embed_dim),\n            torch.nn.SiLU(),\n            torch.nn.Linear(time_embed_dim, self.in_channels),\n        )\n\n        self.merge_strategy = merge_strategy\n        if self.merge_strategy == \"fixed\":\n            self.register_buffer(\"mix_factor\", torch.Tensor([alpha]))\n        elif self.merge_strategy == \"learned\":\n            self.register_parameter(\n                \"mix_factor\", torch.nn.Parameter(torch.Tensor([alpha]))\n            )\n        else:\n            raise ValueError(f\"unknown merge strategy {self.merge_strategy}\")\n\n    def forward(self, x, timesteps, skip_video=False):\n        if skip_video:\n            return super().forward(x)\n\n        x_in = x\n        x = self.attention(x)\n        h, w = x.shape[2:]\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n\n        x_mix = x\n        num_frames = torch.arange(timesteps, device=x.device)\n        num_frames = repeat(num_frames, \"t -> b t\", b=x.shape[0] // timesteps)\n        num_frames = rearrange(num_frames, \"b t -> (b t)\")\n        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)\n        emb = self.video_time_embed(t_emb)  # b, n_channels\n        emb = emb[:, None, :]\n        x_mix = x_mix + emb\n\n        alpha = self.get_alpha()\n        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)\n        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge\n\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        x = self.proj_out(x)\n\n        return x_in + x\n\n    def get_alpha(\n        self,\n    ):\n        if self.merge_strategy == \"fixed\":\n            return self.mix_factor\n        
elif self.merge_strategy == \"learned\":\n            return torch.sigmoid(self.mix_factor)\n        else:\n            raise NotImplementedError(f\"unknown merge strategy {self.merge_strategy}\")\n\n\nclass MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):\n    def __init__(\n        self, in_channels: int, alpha: float = 0, merge_strategy: str = \"learned\"\n    ):\n        super().__init__(in_channels)\n        # no context, single headed, as in base class\n        self.time_mix_block = VideoTransformerBlock(\n            dim=in_channels,\n            n_heads=1,\n            d_head=in_channels,\n            checkpoint=False,\n            ff_in=True,\n            attn_mode=\"softmax-xformers\",\n        )\n\n        time_embed_dim = self.in_channels * 4\n        self.video_time_embed = torch.nn.Sequential(\n            torch.nn.Linear(self.in_channels, time_embed_dim),\n            torch.nn.SiLU(),\n            torch.nn.Linear(time_embed_dim, self.in_channels),\n        )\n\n        self.merge_strategy = merge_strategy\n        if self.merge_strategy == \"fixed\":\n            self.register_buffer(\"mix_factor\", torch.Tensor([alpha]))\n        elif self.merge_strategy == \"learned\":\n            self.register_parameter(\n                \"mix_factor\", torch.nn.Parameter(torch.Tensor([alpha]))\n            )\n        else:\n            raise ValueError(f\"unknown merge strategy {self.merge_strategy}\")\n\n    def forward(self, x, timesteps, skip_time_block=False):\n        if skip_time_block:\n            return super().forward(x)\n\n        x_in = x\n        x = self.attention(x)\n        h, w = x.shape[2:]\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n\n        x_mix = x\n        num_frames = torch.arange(timesteps, device=x.device)\n        num_frames = repeat(num_frames, \"t -> b t\", b=x.shape[0] // timesteps)\n        num_frames = rearrange(num_frames, \"b t -> (b t)\")\n        t_emb = timestep_embedding(num_frames, self.in_channels, 
repeat_only=False)\n        emb = self.video_time_embed(t_emb)  # b, n_channels\n        emb = emb[:, None, :]\n        x_mix = x_mix + emb\n\n        alpha = self.get_alpha()\n        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)\n        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge\n\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        x = self.proj_out(x)\n\n        return x_in + x\n\n    def get_alpha(\n        self,\n    ):\n        if self.merge_strategy == \"fixed\":\n            return self.mix_factor\n        elif self.merge_strategy == \"learned\":\n            return torch.sigmoid(self.mix_factor)\n        else:\n            raise NotImplementedError(f\"unknown merge strategy {self.merge_strategy}\")\n\n\ndef make_time_attn(\n    in_channels,\n    attn_type=\"vanilla\",\n    attn_kwargs=None,\n    alpha: float = 0,\n    merge_strategy: str = \"learned\",\n):\n    assert attn_type in [\n        \"vanilla\",\n        \"vanilla-xformers\",\n    ], f\"attn_type {attn_type} not supported for spatio-temporal attention\"\n    print(\n        f\"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels\"\n    )\n    if not XFORMERS_IS_AVAILABLE and attn_type == \"vanilla-xformers\":\n        print(\n            f\"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. \"\n            f\"This is not a problem in PyTorch >= 2.0. 
FYI, you are running with PyTorch version {torch.__version__}\"\n        )\n        attn_type = \"vanilla\"\n\n    if attn_type == \"vanilla\":\n        assert attn_kwargs is None\n        return partialclass(\n            VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy\n        )\n    elif attn_type == \"vanilla-xformers\":\n        print(f\"building MemoryEfficientAttnBlock with {in_channels} in_channels...\")\n        return partialclass(\n            MemoryEfficientVideoBlock,\n            in_channels,\n            alpha=alpha,\n            merge_strategy=merge_strategy,\n        )\n    else:\n        raise NotImplementedError(f\"attn_type {attn_type} not supported\")\n\n\nclass Conv2DWrapper(torch.nn.Conv2d):\n    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:\n        return super().forward(input)\n\n\nclass VideoDecoder(Decoder):\n    available_time_modes = [\"all\", \"conv-only\", \"attn-only\"]\n\n    def __init__(\n        self,\n        *args,\n        video_kernel_size: Union[int, list] = 3,\n        alpha: float = 0.0,\n        merge_strategy: str = \"learned\",\n        time_mode: str = \"conv-only\",\n        **kwargs,\n    ):\n        self.video_kernel_size = video_kernel_size\n        self.alpha = alpha\n        self.merge_strategy = merge_strategy\n        self.time_mode = time_mode\n        assert (\n            self.time_mode in self.available_time_modes\n        ), f\"time_mode parameter has to be in {self.available_time_modes}\"\n        super().__init__(*args, **kwargs)\n\n    def get_last_layer(self, skip_time_mix=False, **kwargs):\n        if self.time_mode == \"attn-only\":\n            raise NotImplementedError(\"TODO\")\n        else:\n            return (\n                self.conv_out.time_mix_conv.weight\n                if not skip_time_mix\n                else self.conv_out.weight\n            )\n\n    def _make_attn(self) -> Callable:\n        if self.time_mode not in [\"conv-only\", \"only-last-conv\"]:\n            return 
partialclass(\n                make_time_attn,\n                alpha=self.alpha,\n                merge_strategy=self.merge_strategy,\n            )\n        else:\n            return super()._make_attn()\n\n    def _make_conv(self) -> Callable:\n        if self.time_mode != \"attn-only\":\n            return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)\n        else:\n            return Conv2DWrapper\n\n    def _make_resblock(self) -> Callable:\n        if self.time_mode not in [\"attn-only\", \"only-last-conv\"]:\n            return partialclass(\n                VideoResBlock,\n                video_kernel_size=self.video_kernel_size,\n                alpha=self.alpha,\n                merge_strategy=self.merge_strategy,\n            )\n        else:\n            return super()._make_resblock()\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/diffusionmodules/denoiser.py",
    "content": "from typing import Dict, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom ...util import append_dims, instantiate_from_config\nfrom .denoiser_scaling import DenoiserScaling\nfrom .discretizer import Discretization\n\n\nclass Denoiser(nn.Module):\n    def __init__(self, scaling_config: Dict):\n        super().__init__()\n\n        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)\n\n    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:\n        return sigma\n\n    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:\n        return c_noise\n\n    def forward(\n        self,\n        network: nn.Module,\n        input: torch.Tensor,\n        sigma: torch.Tensor,\n        cond: Dict,\n        **additional_model_inputs,\n    ) -> torch.Tensor:\n        sigma = self.possibly_quantize_sigma(sigma)\n        sigma_shape = sigma.shape\n        sigma = append_dims(sigma, input.ndim)\n        c_skip, c_out, c_in, c_noise = self.scaling(sigma)\n        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))\n        return (\n            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out\n            + input * c_skip\n        )\n\n\nclass DiscreteDenoiser(Denoiser):\n    def __init__(\n        self,\n        scaling_config: Dict,\n        num_idx: int,\n        discretization_config: Dict,\n        do_append_zero: bool = False,\n        quantize_c_noise: bool = True,\n        flip: bool = True,\n    ):\n        super().__init__(scaling_config)\n        self.discretization: Discretization = instantiate_from_config(\n            discretization_config\n        )\n        sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)\n        self.register_buffer(\"sigmas\", sigmas)\n        self.quantize_c_noise = quantize_c_noise\n        self.num_idx = num_idx\n\n    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:\n        
dists = sigma - self.sigmas[:, None]\n        return dists.abs().argmin(dim=0).view(sigma.shape)\n\n    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:\n        return self.sigmas[idx]\n\n    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:\n        return self.idx_to_sigma(self.sigma_to_idx(sigma))\n\n    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:\n        if self.quantize_c_noise:\n            return self.sigma_to_idx(c_noise)\n        else:\n            return c_noise\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/denoiser_scaling.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Tuple\n\nimport torch\n\n\nclass DenoiserScaling(ABC):\n    @abstractmethod\n    def __call__(\n        self, sigma: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        pass\n\n\nclass EDMScaling(DenoiserScaling):\n    def __init__(self, sigma_data: float = 0.5):\n        self.sigma_data = sigma_data\n\n    def __call__(\n        self, sigma: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)\n        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5\n        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5\n        c_noise = 0.25 * sigma.log()\n        return c_skip, c_out, c_in, c_noise\n\n\nclass EpsScaling(DenoiserScaling):\n    def __call__(\n        self, sigma: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        c_skip = torch.ones_like(sigma, device=sigma.device)\n        c_out = -sigma\n        c_in = 1 / (sigma**2 + 1.0) ** 0.5\n        c_noise = sigma.clone()\n        return c_skip, c_out, c_in, c_noise\n\n\nclass VScaling(DenoiserScaling):\n    def __call__(\n        self, sigma: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        c_skip = 1.0 / (sigma**2 + 1.0)\n        c_out = -sigma / (sigma**2 + 1.0) ** 0.5\n        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5\n        c_noise = sigma.clone()\n        return c_skip, c_out, c_in, c_noise\n\n\nclass VScalingWithEDMcNoise(DenoiserScaling):\n    def __call__(\n        self, sigma: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        c_skip = 1.0 / (sigma**2 + 1.0)\n        c_out = -sigma / (sigma**2 + 1.0) ** 0.5\n        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5\n        c_noise = 0.25 * sigma.log()\n        return c_skip, c_out, c_in, c_noise\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/denoiser_weighting.py",
    "content": "import torch\n\n\nclass UnitWeighting:\n    def __call__(self, sigma):\n        return torch.ones_like(sigma, device=sigma.device)\n\n\nclass EDMWeighting:\n    def __init__(self, sigma_data=0.5):\n        self.sigma_data = sigma_data\n\n    def __call__(self, sigma):\n        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2\n\n\nclass VWeighting(EDMWeighting):\n    def __init__(self):\n        super().__init__(sigma_data=1.0)\n\n\nclass EpsWeighting:\n    def __call__(self, sigma):\n        return sigma**-2.0\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/discretizer.py",
    "content": "from abc import abstractmethod\nfrom functools import partial\n\nimport numpy as np\nimport torch\n\nfrom ...modules.diffusionmodules.util import make_beta_schedule\nfrom ...util import append_zero\n\n\ndef generate_roughly_equally_spaced_steps(\n    num_substeps: int, max_step: int\n) -> np.ndarray:\n    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]\n\n\nclass Discretization:\n    def __call__(self, n, do_append_zero=True, device=\"cpu\", flip=False):\n        sigmas = self.get_sigmas(n, device=device)\n        sigmas = append_zero(sigmas) if do_append_zero else sigmas\n        return sigmas if not flip else torch.flip(sigmas, (0,))\n\n    @abstractmethod\n    def get_sigmas(self, n, device):\n        pass\n\n\nclass EDMDiscretization(Discretization):\n    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):\n        self.sigma_min = sigma_min\n        self.sigma_max = sigma_max\n        self.rho = rho\n\n    def get_sigmas(self, n, device=\"cpu\"):\n        ramp = torch.linspace(0, 1, n, device=device)\n        min_inv_rho = self.sigma_min ** (1 / self.rho)\n        max_inv_rho = self.sigma_max ** (1 / self.rho)\n        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho\n        return sigmas\n\n\nclass LegacyDDPMDiscretization(Discretization):\n    def __init__(\n        self,\n        linear_start=0.00085,\n        linear_end=0.0120,\n        num_timesteps=1000,\n    ):\n        super().__init__()\n        self.num_timesteps = num_timesteps\n        betas = make_beta_schedule(\n            \"linear\", num_timesteps, linear_start=linear_start, linear_end=linear_end\n        )\n        alphas = 1.0 - betas\n        self.alphas_cumprod = np.cumprod(alphas, axis=0)\n        self.to_torch = partial(torch.tensor, dtype=torch.float32)\n\n    def get_sigmas(self, n, device=\"cpu\"):\n        if n < self.num_timesteps:\n            timesteps = generate_roughly_equally_spaced_steps(n, 
self.num_timesteps)\n            alphas_cumprod = self.alphas_cumprod[timesteps]\n        elif n == self.num_timesteps:\n            alphas_cumprod = self.alphas_cumprod\n        else:\n            raise ValueError\n\n        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)\n        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5\n        return torch.flip(sigmas, (0,))\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/guiders.py",
    "content": "import logging\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport torch\nfrom einops import rearrange, repeat\n\nfrom ...util import append_dims, default\n\nlogpy = logging.getLogger(__name__)\n\n\nclass Guider(ABC):\n    @abstractmethod\n    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:\n        pass\n\n    def prepare_inputs(\n        self, x: torch.Tensor, s: float, c: Dict, uc: Dict\n    ) -> Tuple[torch.Tensor, float, Dict]:\n        pass\n\n\nclass VanillaCFG(Guider):\n    def __init__(self, scale: float):\n        self.scale = scale\n\n    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:\n        x_u, x_c = x.chunk(2)\n        x_pred = x_u + self.scale * (x_c - x_u)\n        return x_pred\n\n    def prepare_inputs(self, x, s, c, uc):\n        c_out = dict()\n\n        for k in c:\n            if k in [\"vector\", \"crossattn\", \"concat\"]:\n                c_out[k] = torch.cat((uc[k], c[k]), 0)\n            else:\n                assert c[k] == uc[k]\n                c_out[k] = c[k]\n        return torch.cat([x] * 2), torch.cat([s] * 2), c_out\n\n\nclass IdentityGuider(Guider):\n    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:\n        return x\n\n    def prepare_inputs(\n        self, x: torch.Tensor, s: float, c: Dict, uc: Dict\n    ) -> Tuple[torch.Tensor, float, Dict]:\n        c_out = dict()\n\n        for k in c:\n            c_out[k] = c[k]\n\n        return x, s, c_out\n\n\nclass LinearPredictionGuider(Guider):\n    def __init__(\n        self,\n        max_scale: float,\n        num_frames: int,\n        min_scale: float = 1.0,\n        additional_cond_keys: Optional[Union[List[str], str]] = None,\n    ):\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.num_frames = num_frames\n        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)\n\n        
additional_cond_keys = default(additional_cond_keys, [])\n        if isinstance(additional_cond_keys, str):\n            additional_cond_keys = [additional_cond_keys]\n        self.additional_cond_keys = additional_cond_keys\n\n    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:\n        x_u, x_c = x.chunk(2)\n\n        x_u = rearrange(x_u, \"(b t) ... -> b t ...\", t=self.num_frames)\n        x_c = rearrange(x_c, \"(b t) ... -> b t ...\", t=self.num_frames)\n        scale = repeat(self.scale, \"1 t -> b t\", b=x_u.shape[0])\n        scale = append_dims(scale, x_u.ndim).to(x_u.device)\n\n        return rearrange(x_u + scale * (x_c - x_u), \"b t ... -> (b t) ...\")\n\n    def prepare_inputs(\n        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict\n    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:\n        c_out = dict()\n\n        for k in c:\n            if k in [\"vector\", \"crossattn\", \"concat\"] + self.additional_cond_keys:\n                c_out[k] = torch.cat((uc[k], c[k]), 0)\n            else:\n                # assert c[k] == uc[k]\n                c_out[k] = c[k]\n        return torch.cat([x] * 2), torch.cat([s] * 2), c_out\n\n\nclass TrianglePredictionGuider(LinearPredictionGuider):\n    def __init__(\n        self,\n        max_scale: float,\n        num_frames: int,\n        min_scale: float = 1.0,\n        period: Union[float, List[float]] = 1.0,\n        period_fusing: Literal[\"mean\", \"multiply\", \"max\"] = \"max\",\n        additional_cond_keys: Optional[Union[List[str], str]] = None,\n    ):\n        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)\n        values = torch.linspace(0, 1, num_frames)\n        # Constructs a triangle wave\n        if isinstance(period, float):\n            period = [period]\n\n        scales = []\n        for p in period:\n            scales.append(self.triangle_wave(values, p))\n\n        if period_fusing == \"mean\":\n            scale = sum(scales) 
/ len(period)\n        elif period_fusing == \"multiply\":\n            scale = torch.prod(torch.stack(scales), dim=0)\n        elif period_fusing == \"max\":\n            scale = torch.max(torch.stack(scales), dim=0).values\n        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)\n\n    def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:\n        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()\n\n\nclass TrapezoidPredictionGuider(LinearPredictionGuider):\n    def __init__(\n        self,\n        max_scale: float,\n        num_frames: int,\n        min_scale: float = 1.0,\n        edge_perc: float = 0.1,\n        additional_cond_keys: Optional[Union[List[str], str]] = None,\n    ):\n        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)\n\n        rise_steps = torch.linspace(min_scale, max_scale, int(num_frames * edge_perc))\n        fall_steps = torch.flip(rise_steps, [0])\n        self.scale = torch.cat(\n            [\n                rise_steps,\n                max_scale * torch.ones(num_frames - 2 * int(num_frames * edge_perc)),\n                fall_steps,\n            ]\n        ).unsqueeze(0)\n\n\nclass SpatiotemporalPredictionGuider(LinearPredictionGuider):\n    def __init__(\n        self,\n        max_scale: float,\n        num_frames: int,\n        num_views: int = 1,\n        min_scale: float = 1.0,\n        additional_cond_keys: Optional[Union[List[str], str]] = None,\n    ):\n        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)\n        V = num_views\n        T = num_frames // V\n        scale = torch.zeros(num_frames).view(T, V)\n        scale += torch.linspace(0, 1, T)[:, None] * 0.5\n        scale += self.triangle_wave(torch.linspace(0, 1, V))[None, :] * 0.5\n        scale = scale.flatten()\n        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)\n\n    def triangle_wave(self, values: torch.Tensor, period=1) 
-> torch.Tensor:\n        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/loss.py",
    "content": "from typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom ...modules.autoencoding.lpips.loss.lpips import LPIPS\nfrom ...modules.encoders.modules import GeneralConditioner\nfrom ...util import append_dims, instantiate_from_config\nfrom .denoiser import Denoiser\n\n\nclass StandardDiffusionLoss(nn.Module):\n    def __init__(\n        self,\n        sigma_sampler_config: dict,\n        loss_weighting_config: dict,\n        loss_type: str = \"l2\",\n        offset_noise_level: float = 0.0,\n        batch2model_keys: Optional[Union[str, List[str]]] = None,\n    ):\n        super().__init__()\n\n        assert loss_type in [\"l2\", \"l1\", \"lpips\"]\n\n        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)\n        self.loss_weighting = instantiate_from_config(loss_weighting_config)\n\n        self.loss_type = loss_type\n        self.offset_noise_level = offset_noise_level\n\n        if loss_type == \"lpips\":\n            self.lpips = LPIPS().eval()\n\n        if not batch2model_keys:\n            batch2model_keys = []\n\n        if isinstance(batch2model_keys, str):\n            batch2model_keys = [batch2model_keys]\n\n        self.batch2model_keys = set(batch2model_keys)\n\n    def get_noised_input(\n        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor\n    ) -> torch.Tensor:\n        noised_input = input + noise * sigmas_bc\n        return noised_input\n\n    def forward(\n        self,\n        network: nn.Module,\n        denoiser: Denoiser,\n        conditioner: GeneralConditioner,\n        input: torch.Tensor,\n        batch: Dict,\n    ) -> torch.Tensor:\n        cond = conditioner(batch)\n        return self._forward(network, denoiser, cond, input, batch)\n\n    def _forward(\n        self,\n        network: nn.Module,\n        denoiser: Denoiser,\n        cond: Dict,\n        input: torch.Tensor,\n        batch: Dict,\n    ) -> Tuple[torch.Tensor, 
Dict]:\n        additional_model_inputs = {\n            key: batch[key] for key in self.batch2model_keys.intersection(batch)\n        }\n        sigmas = self.sigma_sampler(input.shape[0]).to(input)\n\n        noise = torch.randn_like(input)\n        if self.offset_noise_level > 0.0:\n            offset_shape = (\n                (input.shape[0], 1, input.shape[2])\n                if getattr(self, \"n_frames\", None) is not None\n                else (input.shape[0], input.shape[1])\n            )\n            noise = noise + self.offset_noise_level * append_dims(\n                torch.randn(offset_shape, device=input.device),\n                input.ndim,\n            )\n        sigmas_bc = append_dims(sigmas, input.ndim)\n        noised_input = self.get_noised_input(sigmas_bc, noise, input)\n\n        model_output = denoiser(\n            network, noised_input, sigmas, cond, **additional_model_inputs\n        )\n        w = append_dims(self.loss_weighting(sigmas), input.ndim)\n        return self.get_loss(model_output, input, w)\n\n    def get_loss(self, model_output, target, w):\n        if self.loss_type == \"l2\":\n            return torch.mean(\n                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1\n            )\n        elif self.loss_type == \"l1\":\n            return torch.mean(\n                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1\n            )\n        elif self.loss_type == \"lpips\":\n            loss = self.lpips(model_output, target).reshape(-1)\n            return loss\n        else:\n            raise NotImplementedError(f\"Unknown loss type {self.loss_type}\")\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/loss_weighting.py",
    "content": "from abc import ABC, abstractmethod\n\nimport torch\n\n\nclass DiffusionLossWeighting(ABC):\n    @abstractmethod\n    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:\n        pass\n\n\nclass UnitWeighting(DiffusionLossWeighting):\n    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:\n        return torch.ones_like(sigma, device=sigma.device)\n\n\nclass EDMWeighting(DiffusionLossWeighting):\n    def __init__(self, sigma_data: float = 0.5):\n        self.sigma_data = sigma_data\n\n    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:\n        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2\n\n\nclass VWeighting(EDMWeighting):\n    def __init__(self):\n        super().__init__(sigma_data=1.0)\n\n\nclass EpsWeighting(DiffusionLossWeighting):\n    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:\n        return sigma**-2.0\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/model.py",
    "content": "# pytorch_diffusion + derived encoder decoder\nimport logging\nimport math\nfrom typing import Any, Callable, Optional\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom packaging import version\n\nlogpy = logging.getLogger(__name__)\n\ntry:\n    import xformers\n    import xformers.ops\n\n    XFORMERS_IS_AVAILABLE = True\nexcept:\n    XFORMERS_IS_AVAILABLE = False\n    logpy.warning(\"no module 'xformers'. Processing without...\")\n\nfrom ...modules.attention import LinearAttention, MemoryEfficientCrossAttention\n\n\ndef get_timestep_embedding(timesteps, embedding_dim):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models:\n    From Fairseq.\n    Build sinusoidal embeddings.\n    This matches the implementation in tensor2tensor, but differs slightly\n    from the description in Section 3.5 of \"Attention Is All You Need\".\n    \"\"\"\n    assert len(timesteps.shape) == 1\n\n    half_dim = embedding_dim // 2\n    emb = math.log(10000) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)\n    emb = emb.to(device=timesteps.device)\n    emb = timesteps.float()[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return emb\n\n\ndef nonlinearity(x):\n    # swish\n    return x * torch.sigmoid(x)\n\n\ndef Normalize(in_channels, num_groups=32):\n    return torch.nn.GroupNorm(\n        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True\n    )\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=3, stride=1, padding=1\n            )\n\n    def forward(self, x):\n       
 x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=3, stride=2, padding=0\n            )\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0, 1, 0, 1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        in_channels,\n        out_channels=None,\n        conv_shortcut=False,\n        dropout,\n        temb_channels=512,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(\n            in_channels, out_channels, kernel_size=3, stride=1, padding=1\n        )\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(\n            out_channels, out_channels, kernel_size=3, stride=1, padding=1\n        )\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(\n                    in_channels, out_channels, kernel_size=3, 
stride=1, padding=1\n                )\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(\n                    in_channels, out_channels, kernel_size=1, stride=1, padding=0\n                )\n\n    def forward(self, x, temb):\n        h = x\n        h = self.norm1(h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]\n\n        h = self.norm2(h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n\n        return x + h\n\n\nclass LinAttnBlock(LinearAttention):\n    \"\"\"to match AttnBlock usage\"\"\"\n\n    def __init__(self, in_channels):\n        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.k = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.v = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.proj_out = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n\n    def attention(self, h_: torch.Tensor) -> torch.Tensor:\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        b, c, h, w = q.shape\n        q, k, v = map(\n            lambda x: rearrange(x, \"b c h w -> b 1 (h w) c\").contiguous(), (q, k, v)\n        )\n        h_ = 
torch.nn.functional.scaled_dot_product_attention(\n            q, k, v\n        )  # scale is dim ** -0.5 per default\n        # compute attention\n\n        return rearrange(h_, \"b 1 (h w) c -> b c h w\", h=h, w=w, c=c, b=b)\n\n    def forward(self, x, **kwargs):\n        h_ = x\n        h_ = self.attention(h_)\n        h_ = self.proj_out(h_)\n        return x + h_\n\n\nclass MemoryEfficientAttnBlock(nn.Module):\n    \"\"\"\n    Uses xformers efficient implementation,\n    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223\n    Note: this is a single-head self-attention operation\n    \"\"\"\n\n    #\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.k = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.v = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.proj_out = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.attention_op: Optional[Any] = None\n\n    def attention(self, h_: torch.Tensor) -> torch.Tensor:\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        B, C, H, W = q.shape\n        q, k, v = map(lambda x: rearrange(x, \"b c h w -> b (h w) c\"), (q, k, v))\n\n        q, k, v = map(\n            lambda t: t.unsqueeze(3)\n            .reshape(B, t.shape[1], 1, C)\n            .permute(0, 2, 1, 3)\n            .reshape(B * 1, t.shape[1], C)\n            .contiguous(),\n            (q, k, v),\n        )\n        out = xformers.ops.memory_efficient_attention(\n          
  q, k, v, attn_bias=None, op=self.attention_op\n        )\n\n        out = (\n            out.unsqueeze(0)\n            .reshape(B, 1, out.shape[1], C)\n            .permute(0, 2, 1, 3)\n            .reshape(B, out.shape[1], C)\n        )\n        return rearrange(out, \"b (h w) c -> b c h w\", b=B, h=H, w=W, c=C)\n\n    def forward(self, x, **kwargs):\n        h_ = x\n        h_ = self.attention(h_)\n        h_ = self.proj_out(h_)\n        return x + h_\n\n\nclass MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):\n    def forward(self, x, context=None, mask=None, **unused_kwargs):\n        b, c, h, w = x.shape\n        x_in = x  # keep the (b, c, h, w) input for the residual connection\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n        out = super().forward(x, context=context, mask=mask)\n        out = rearrange(out, \"b (h w) c -> b c h w\", h=h, w=w, c=c)\n        return x_in + out\n\n\ndef make_attn(in_channels, attn_type=\"vanilla\", attn_kwargs=None):\n    assert attn_type in [\n        \"vanilla\",\n        \"vanilla-xformers\",\n        \"memory-efficient-cross-attn\",\n        \"linear\",\n        \"none\",\n    ], f\"attn_type {attn_type} unknown\"\n    if (\n        version.parse(torch.__version__) < version.parse(\"2.0.0\")\n        and attn_type != \"none\"\n    ):\n        assert XFORMERS_IS_AVAILABLE, (\n            f\"We do not support vanilla attention in {torch.__version__} anymore, \"\n            f\"as it is too expensive. Please install xformers via e.g. 
'pip install xformers==0.0.16'\"\n        )\n        attn_type = \"vanilla-xformers\"\n    logpy.info(f\"making attention of type '{attn_type}' with {in_channels} in_channels\")\n    if attn_type == \"vanilla\":\n        assert attn_kwargs is None\n        return AttnBlock(in_channels)\n    elif attn_type == \"vanilla-xformers\":\n        logpy.info(\n            f\"building MemoryEfficientAttnBlock with {in_channels} in_channels...\"\n        )\n        return MemoryEfficientAttnBlock(in_channels)\n    elif attn_type == \"memory-efficient-cross-attn\":\n        attn_kwargs[\"query_dim\"] = in_channels\n        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)\n    elif attn_type == \"none\":\n        return nn.Identity(in_channels)\n    else:\n        return LinAttnBlock(in_channels)\n\n\nclass Model(nn.Module):\n    def __init__(\n        self,\n        *,\n        ch,\n        out_ch,\n        ch_mult=(1, 2, 4, 8),\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        in_channels,\n        resolution,\n        use_timestep=True,\n        use_linear_attn=False,\n        attn_type=\"vanilla\",\n    ):\n        super().__init__()\n        if use_linear_attn:\n            attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = self.ch * 4\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        self.use_timestep = use_timestep\n        if self.use_timestep:\n            # timestep embedding\n            self.temb = nn.Module()\n            self.temb.dense = nn.ModuleList(\n                [\n                    torch.nn.Linear(self.ch, self.temb_ch),\n                    torch.nn.Linear(self.temb_ch, self.temb_ch),\n                ]\n            )\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(\n            in_channels, self.ch, kernel_size=3, 
stride=1, padding=1\n        )\n\n        curr_res = resolution\n        in_ch_mult = (1,) + tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch * in_ch_mult[i_level]\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(\n                    ResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions - 1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch * ch_mult[i_level]\n            skip_in = ch * ch_mult[i_level]\n            for 
i_block in range(self.num_res_blocks + 1):\n                if i_block == self.num_res_blocks:\n                    skip_in = ch * in_ch_mult[i_level]\n                block.append(\n                    ResnetBlock(\n                        in_channels=block_in + skip_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up)  # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(\n            block_in, out_ch, kernel_size=3, stride=1, padding=1\n        )\n\n    def forward(self, x, t=None, context=None):\n        # assert x.shape[2] == x.shape[3] == self.resolution\n        if context is not None:\n            # assume aligned context, cat along channel axis\n            x = torch.cat((x, context), dim=1)\n        if self.use_timestep:\n            # timestep embedding\n            assert t is not None\n            temb = get_timestep_embedding(t, self.ch)\n            temb = self.temb.dense[0](temb)\n            temb = nonlinearity(temb)\n            temb = self.temb.dense[1](temb)\n        else:\n            temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = 
self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions - 1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.up[i_level].block[i_block](\n                    torch.cat([h, hs.pop()], dim=1), temb\n                )\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n    def get_last_layer(self):\n        return self.conv_out.weight\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        *,\n        ch,\n        out_ch,\n        ch_mult=(1, 2, 4, 8),\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        in_channels,\n        resolution,\n        z_channels,\n        double_z=True,\n        use_linear_attn=False,\n        attn_type=\"vanilla\",\n        **ignore_kwargs,\n    ):\n        super().__init__()\n        if use_linear_attn:\n            attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(\n            in_channels, self.ch, kernel_size=3, stride=1, padding=1\n        )\n\n        curr_res = resolution\n        in_ch_mult = (1,) + tuple(ch_mult)\n        self.in_ch_mult = 
in_ch_mult\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch * in_ch_mult[i_level]\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(\n                    ResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions - 1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(\n            block_in,\n            2 * z_channels if double_z else z_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n    def forward(self, x):\n        # timestep embedding\n        temb = None\n\n        # downsampling\n        hs = 
[self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions - 1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        *,\n        ch,\n        out_ch,\n        ch_mult=(1, 2, 4, 8),\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        in_channels,\n        resolution,\n        z_channels,\n        give_pre_end=False,\n        tanh_out=False,\n        use_linear_attn=False,\n        attn_type=\"vanilla\",\n        **ignorekwargs,\n    ):\n        super().__init__()\n        if use_linear_attn:\n            attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.tanh_out = tanh_out\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,) + tuple(ch_mult)\n        block_in = ch * ch_mult[self.num_resolutions - 1]\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.z_shape = (1, z_channels, curr_res, curr_res)\n        logpy.info(\n            \"Working with z of shape {} = {} dimensions.\".format(\n                self.z_shape, 
np.prod(self.z_shape)\n            )\n        )\n\n        make_attn_cls = self._make_attn()\n        make_resblock_cls = self._make_resblock()\n        make_conv_cls = self._make_conv()\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(\n            z_channels, block_in, kernel_size=3, stride=1, padding=1\n        )\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = make_resblock_cls(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n        self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)\n        self.mid.block_2 = make_resblock_cls(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                block.append(\n                    make_resblock_cls(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn_cls(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up)  # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        
self.conv_out = make_conv_cls(\n            block_in, out_ch, kernel_size=3, stride=1, padding=1\n        )\n\n    def _make_attn(self) -> Callable:\n        return make_attn\n\n    def _make_resblock(self) -> Callable:\n        return ResnetBlock\n\n    def _make_conv(self) -> Callable:\n        return torch.nn.Conv2d\n\n    def get_last_layer(self, **kwargs):\n        return self.conv_out.weight\n\n    def forward(self, z, **kwargs):\n        # assert z.shape[1:] == self.z_shape[1:]\n        self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h = self.conv_in(z)\n\n        # middle\n        h = self.mid.block_1(h, temb, **kwargs)\n        h = self.mid.attn_1(h, **kwargs)\n        h = self.mid.block_2(h, temb, **kwargs)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.up[i_level].block[i_block](h, temb, **kwargs)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h, **kwargs)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h, **kwargs)\n        if self.tanh_out:\n            h = torch.tanh(h)\n        return h\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/openaimodel.py",
    "content": "import logging\nimport math\nfrom abc import abstractmethod\nfrom typing import Iterable, List, Optional, Tuple, Union\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch.utils.checkpoint import checkpoint\n\nfrom ...modules.attention import SpatialTransformer\nfrom ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear,\n                                              normalization,\n                                              timestep_embedding, zero_module)\nfrom ...modules.video_attention import SpatialVideoTransformer\nfrom ...util import exists\n\nlogpy = logging.getLogger(__name__)\n\n\nclass AttentionPool2d(nn.Module):\n    \"\"\"\n    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py\n    \"\"\"\n\n    def __init__(\n        self,\n        spacial_dim: int,\n        embed_dim: int,\n        num_heads_channels: int,\n        output_dim: Optional[int] = None,\n    ):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(\n            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5\n        )\n        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)\n        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)\n        self.num_heads = embed_dim // num_heads_channels\n        self.attention = QKVAttention(self.num_heads)\n\n    def forward(self, x: th.Tensor) -> th.Tensor:\n        b, c, _ = x.shape\n        x = x.reshape(b, c, -1)\n        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)\n        x = x + self.positional_embedding[None, :, :].to(x.dtype)\n        x = self.qkv_proj(x)\n        x = self.attention(x)\n        x = self.c_proj(x)\n        return x[:, :, 0]\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x: th.Tensor, emb: th.Tensor):\n 
       \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(\n        self,\n        x: th.Tensor,\n        emb: th.Tensor,\n        context: Optional[th.Tensor] = None,\n        cam: Optional[th.Tensor] = None,\n        image_only_indicator: Optional[th.Tensor] = None,\n        cond_view: Optional[th.Tensor] = None,\n        cond_motion: Optional[th.Tensor] = None,\n        time_context: Optional[int] = None,\n        num_video_frames: Optional[int] = None,\n        time_step: Optional[int] = None,\n        name: Optional[str] = None,\n    ):\n        from ...modules.diffusionmodules.video_model import VideoResBlock, PostHocResBlockWithTime\n        from ...modules.spacetime_attention import (\n            BasicTransformerTimeMixBlock,\n            PostHocSpatialTransformerWithTimeMixing,\n            PostHocSpatialTransformerWithTimeMixingAndMotion,\n        )\n\n        for layer in self:\n            module = layer\n\n            if isinstance(\n                module,\n                (\n                    BasicTransformerTimeMixBlock,\n                    PostHocSpatialTransformerWithTimeMixing,\n                ),\n            ):\n                x = layer(\n                    x,\n                    context,\n                    emb,\n                    time_context,\n                    num_video_frames,\n                    image_only_indicator,\n                    cond_view,\n                    cond_motion,\n                    time_step,\n                    name,\n                )\n            elif isinstance(\n                module, \n                (\n                    PostHocSpatialTransformerWithTimeMixingAndMotion,\n                ),\n            ):\n                x = layer(\n      
              x,\n                    context,\n                    emb,\n                    time_context,\n                    num_video_frames,\n                    image_only_indicator,\n                    cond_view,\n                    cond_motion,\n                    time_step,\n                    name,\n                )\n            elif isinstance(module, SpatialVideoTransformer):\n                x = layer(\n                    x,\n                    context,\n                    time_context,\n                    num_video_frames,\n                    image_only_indicator,\n                    # time_step,\n                )\n            elif isinstance(module, PostHocResBlockWithTime):\n                x = layer(x, emb, num_video_frames, image_only_indicator)\n            elif isinstance(module, VideoResBlock):\n                x = layer(x, emb, num_video_frames, image_only_indicator)\n            elif isinstance(module, TimestepBlock) and not isinstance(\n                module, VideoResBlock\n            ):\n                x = layer(x, emb)\n            elif isinstance(module, SpatialTransformer):\n                x = layer(x, context)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels: int,\n        use_conv: bool,\n        dims: int = 2,\n        out_channels: Optional[int] = None,\n        padding: int = 1,\n        third_up: bool = False,\n        kernel_size: int = 3,\n        scale_factor: int = 2,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        self.third_up = third_up\n        self.scale_factor = scale_factor\n        if use_conv:\n            self.conv = conv_nd(\n                dims, self.channels, self.out_channels, kernel_size, padding=padding\n            )\n\n    def forward(self, x: th.Tensor) -> th.Tensor:\n        assert x.shape[1] == self.channels\n\n        if self.dims == 3:\n            t_factor = 1 if not self.third_up else self.scale_factor\n            x = F.interpolate(\n                x,\n                (\n                    t_factor * x.shape[2],\n                    x.shape[3] * self.scale_factor,\n                    x.shape[4] * self.scale_factor,\n                ),\n                mode=\"nearest\",\n            )\n        else:\n            x = F.interpolate(x, scale_factor=self.scale_factor, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels: int,\n        use_conv: bool,\n        dims: int = 2,\n        out_channels: Optional[int] = None,\n        padding: int = 1,\n        third_down: bool = False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))\n        if use_conv:\n            logpy.info(f\"Building a Downsample layer with {dims} dims.\")\n            logpy.info(\n                f\"  --> settings are: \\n in-chn: {self.channels}, out-chn: {self.out_channels}, \"\n                f\"kernel-size: 3, stride: {stride}, padding: {padding}\"\n            )\n            if dims == 3:\n                logpy.info(f\"  --> Downsampling third axis (time): {third_down}\")\n            self.op = conv_nd(\n                dims,\n                self.channels,\n                self.out_channels,\n                3,\n                stride=stride,\n                padding=padding,\n            )\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x: th.Tensor) -> th.Tensor:\n        assert x.shape[1] == self.channels\n\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n   
     channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param use_checkpoint: if True, use gradient checkpointing on this module.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use this block for downsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels: int,\n        emb_channels: int,\n        dropout: float,\n        out_channels: Optional[int] = None,\n        use_conv: bool = False,\n        use_scale_shift_norm: bool = False,\n        dims: int = 2,\n        use_checkpoint: bool = False,\n        up: bool = False,\n        down: bool = False,\n        kernel_size: int = 3,\n        exchange_temb_dims: bool = False,\n        skip_t_emb: bool = False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_checkpoint = use_checkpoint\n        self.use_scale_shift_norm = use_scale_shift_norm\n        self.exchange_temb_dims = exchange_temb_dims\n\n        if isinstance(kernel_size, Iterable):\n            padding = [k // 2 for k in kernel_size]\n        else:\n            padding = kernel_size // 2\n\n        self.in_layers = nn.Sequential(\n            normalization(channels),\n            nn.SiLU(),\n            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims)\n            self.x_upd = Upsample(channels, False, dims)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims)\n            self.x_upd = Downsample(channels, False, dims)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.skip_t_emb = skip_t_emb\n        self.emb_out_channels = (\n            2 * 
self.out_channels if use_scale_shift_norm else self.out_channels\n        )\n        if self.skip_t_emb:\n            logpy.info(f\"Skipping timestep embedding in {self.__class__.__name__}\")\n            assert not self.use_scale_shift_norm\n            self.emb_layers = None\n            self.exchange_temb_dims = False\n        else:\n            self.emb_layers = nn.Sequential(\n                nn.SiLU(),\n                linear(\n                    emb_channels,\n                    self.emb_out_channels,\n                ),\n            )\n\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels),\n            nn.SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(\n                conv_nd(\n                    dims,\n                    self.out_channels,\n                    self.out_channels,\n                    kernel_size,\n                    padding=padding,\n                )\n            ),\n        )\n\n        if self.out_channels == channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, kernel_size, padding=padding\n            )\n        else:\n            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)\n\n    def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] 
Tensor of outputs.\n        \"\"\"\n        if self.use_checkpoint:\n            return checkpoint(self._forward, x, emb)\n        else:\n            return self._forward(x, emb)\n\n    def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n\n        if self.skip_t_emb:\n            emb_out = th.zeros_like(h)\n        else:\n            emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = th.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            if self.exchange_temb_dims:\n                emb_out = rearrange(emb_out, \"b t c ... 
-> b c t ...\")\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels: int,\n        num_heads: int = 1,\n        num_head_channels: int = -1,\n        use_checkpoint: bool = False,\n        use_new_attention_order: bool = False,\n    ):\n        super().__init__()\n        self.channels = channels\n        if num_head_channels == -1:\n            self.num_heads = num_heads\n        else:\n            assert (\n                channels % num_head_channels == 0\n            ), f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n            self.num_heads = channels // num_head_channels\n        self.use_checkpoint = use_checkpoint\n        self.norm = normalization(channels)\n        self.qkv = conv_nd(1, channels, channels * 3, 1)\n        if use_new_attention_order:\n            # split qkv before split heads\n            self.attention = QKVAttention(self.num_heads)\n        else:\n            # split heads before split qkv\n            self.attention = QKVAttentionLegacy(self.num_heads)\n\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n\n    def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:\n        return checkpoint(self._forward, x)\n\n    def _forward(self, x: th.Tensor) -> th.Tensor:\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1)\n        qkv = self.qkv(self.norm(x))\n        h = self.attention(qkv)\n        h = self.proj_out(h)\n        return (x + h).reshape(b, c, *spatial)\n\n\nclass QKVAttentionLegacy(nn.Module):\n    \"\"\"\n    A 
module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads: int):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv: th.Tensor) -> th.Tensor:\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\", q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight, v)\n        return a.reshape(bs, -1, length)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention and splits in a different order.\n    \"\"\"\n\n    def __init__(self, n_heads: int):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv: th.Tensor) -> th.Tensor:\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.chunk(3, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\",\n            (q * scale).view(bs * self.n_heads, ch, length),\n            (k * scale).view(bs * self.n_heads, ch, length),\n        )  # More stable with f16 than dividing afterwards\n        weight = 
th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight, v.reshape(bs * self.n_heads, ch, length))\n        return a.reshape(bs, -1, length)\n\n\nclass Timestep(nn.Module):\n    def __init__(self, dim: int):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, t: th.Tensor) -> th.Tensor:\n        return timestep_embedding(t, self.dim)\n\n\nclass UNetModel(nn.Module):\n    \"\"\"\n    The full UNet model with attention and timestep embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified as an int, this model will be\n        class-conditional with `num_classes` classes. The strings \"continuous\",\n        \"timestep\", and \"sequential\" select alternative label embeddings.\n    :param use_checkpoint: use gradient checkpointing to reduce memory usage.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_head_channels: if specified, ignore num_heads and instead use\n                               a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling. 
Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    :param use_new_attention_order: use a different attention pattern for potentially\n                                    increased efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        model_channels: int,\n        out_channels: int,\n        num_res_blocks: int,\n        attention_resolutions: int,\n        dropout: float = 0.0,\n        channel_mult: Union[List, Tuple] = (1, 2, 4, 8),\n        conv_resample: bool = True,\n        dims: int = 2,\n        num_classes: Optional[Union[int, str]] = None,\n        use_checkpoint: bool = False,\n        num_heads: int = -1,\n        num_head_channels: int = -1,\n        num_heads_upsample: int = -1,\n        use_scale_shift_norm: bool = False,\n        resblock_updown: bool = False,\n        transformer_depth: int = 1,\n        context_dim: Optional[int] = None,\n        disable_self_attentions: Optional[List[bool]] = None,\n        num_attention_blocks: Optional[List[int]] = None,\n        disable_middle_self_attn: bool = False,\n        disable_middle_transformer: bool = False,\n        use_linear_in_transformer: bool = False,\n        spatial_transformer_attn_type: str = \"softmax\",\n        adm_in_channels: Optional[int] = None,\n    ):\n        super().__init__()\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert (\n                num_head_channels != -1\n            ), \"Either num_heads or num_head_channels has to be set\"\n\n        if num_head_channels == -1:\n            assert (\n                num_heads != -1\n            ), \"Either num_heads or num_head_channels has to be set\"\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        if 
isinstance(transformer_depth, int):\n            transformer_depth = len(channel_mult) * [transformer_depth]\n        transformer_depth_middle = transformer_depth[-1]\n\n        if isinstance(num_res_blocks, int):\n            self.num_res_blocks = len(channel_mult) * [num_res_blocks]\n        else:\n            if len(num_res_blocks) != len(channel_mult):\n                raise ValueError(\n                    \"provide num_res_blocks either as an int (globally constant) or \"\n                    \"as a list/tuple (per-level) with the same length as channel_mult\"\n                )\n            self.num_res_blocks = num_res_blocks\n\n        if disable_self_attentions is not None:\n            assert len(disable_self_attentions) == len(channel_mult)\n        if num_attention_blocks is not None:\n            assert len(num_attention_blocks) == len(self.num_res_blocks)\n            assert all(\n                map(\n                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],\n                    range(len(num_attention_blocks)),\n                )\n            )\n            logpy.info(\n                f\"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
\"\n                f\"This option has LESS priority than attention_resolutions {attention_resolutions}, \"\n                f\"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, \"\n                f\"attention will still not be set.\"\n            )\n\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            if isinstance(self.num_classes, int):\n                self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n            elif self.num_classes == \"continuous\":\n                logpy.info(\"setting up linear c_adm embedding layer\")\n                self.label_emb = nn.Linear(1, time_embed_dim)\n            elif self.num_classes == \"timestep\":\n                self.label_emb = nn.Sequential(\n                    Timestep(model_channels),\n                    nn.Sequential(\n                        linear(model_channels, time_embed_dim),\n                        nn.SiLU(),\n                        linear(time_embed_dim, time_embed_dim),\n                    ),\n                )\n            elif self.num_classes == \"sequential\":\n                assert adm_in_channels is not None\n                self.label_emb = nn.Sequential(\n                    nn.Sequential(\n                        linear(adm_in_channels, time_embed_dim),\n                        nn.SiLU(),\n                        
linear(time_embed_dim, time_embed_dim),\n                    )\n                )\n            else:\n                raise ValueError\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for nr in range(self.num_res_blocks[level]):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n\n                    if context_dim is not None and exists(disable_self_attentions):\n                        disabled_sa = disable_self_attentions[level]\n                    else:\n                        disabled_sa = False\n\n                    if (\n                        not exists(num_attention_blocks)\n                        or nr < num_attention_blocks[level]\n                    ):\n                        layers.append(\n                            SpatialTransformer(\n                                ch,\n                                num_heads,\n                                dim_head,\n                
                depth=transformer_depth[level],\n                                context_dim=context_dim,\n                                disable_self_attn=disabled_sa,\n                                use_linear=use_linear_in_transformer,\n                                attn_type=spatial_transformer_attn_type,\n                                use_checkpoint=use_checkpoint,\n                            )\n                        )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                out_channels=ch,\n                dims=dims,\n          
      use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            SpatialTransformer(\n                ch,\n                num_heads,\n                dim_head,\n                depth=transformer_depth_middle,\n                context_dim=context_dim,\n                disable_self_attn=disable_middle_self_attn,\n                use_linear=use_linear_in_transformer,\n                attn_type=spatial_transformer_attn_type,\n                use_checkpoint=use_checkpoint,\n            )\n            if not disable_middle_transformer\n            else th.nn.Identity(),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(self.num_res_blocks[level] + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n\n                    if exists(disable_self_attentions):\n                
        disabled_sa = disable_self_attentions[level]\n                    else:\n                        disabled_sa = False\n\n                    if (\n                        not exists(num_attention_blocks)\n                        or i < num_attention_blocks[level]\n                    ):\n                        layers.append(\n                            SpatialTransformer(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                depth=transformer_depth[level],\n                                context_dim=context_dim,\n                                disable_self_attn=disabled_sa,\n                                use_linear=use_linear_in_transformer,\n                                attn_type=spatial_transformer_attn_type,\n                                use_checkpoint=use_checkpoint,\n                            )\n                        )\n                if level and i == self.num_res_blocks[level]:\n                    out_ch = ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, ch, out_channels, 
3, padding=1)),\n        )\n\n    def forward(\n        self,\n        x: th.Tensor,\n        timesteps: Optional[th.Tensor] = None,\n        context: Optional[th.Tensor] = None,\n        y: Optional[th.Tensor] = None,\n        **kwargs,\n    ) -> th.Tensor:\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional\"\n        hs = []\n        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)\n        emb = self.time_embed(t_emb)\n\n        if self.num_classes is not None:\n            assert y.shape[0] == x.shape[0]\n            emb = emb + self.label_emb(y)\n\n        h = x\n        for module in self.input_blocks:\n            h = module(h, emb, context)\n            hs.append(h)\n        h = self.middle_block(h, emb, context)\n        for module in self.output_blocks:\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, context)\n        h = h.type(x.dtype)\n\n        return self.out(h)\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/sampling.py",
    "content": "\"\"\"\n    Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py\n\"\"\"\n\n\nfrom typing import Dict, Union\n\nimport torch\nfrom omegaconf import ListConfig, OmegaConf\nfrom tqdm import tqdm\n\nfrom ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,\n                                                        linear_multistep_coeff,\n                                                        to_d, to_neg_log_sigma,\n                                                        to_sigma)\nfrom ...util import append_dims, default, instantiate_from_config\n\nDEFAULT_GUIDER = {\"target\": \"sgm.modules.diffusionmodules.guiders.IdentityGuider\"}\n\n\nclass BaseDiffusionSampler:\n    def __init__(\n        self,\n        discretization_config: Union[Dict, ListConfig, OmegaConf],\n        num_steps: Union[int, None] = None,\n        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,\n        verbose: bool = False,\n        device: str = \"cuda\",\n    ):\n        self.num_steps = num_steps\n        self.discretization = instantiate_from_config(discretization_config)\n        self.guider = instantiate_from_config(\n            default(\n                guider_config,\n                DEFAULT_GUIDER,\n            )\n        )\n        self.verbose = verbose\n        self.device = device\n\n    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):\n        sigmas = self.discretization(\n            self.num_steps if num_steps is None else num_steps, device=self.device\n        )\n        uc = default(uc, cond)\n\n        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)\n        num_sigmas = len(sigmas)\n\n        s_in = x.new_ones([x.shape[0]])\n\n        return x, s_in, sigmas, num_sigmas, cond, uc\n\n    def denoise(self, x, denoiser, sigma, cond, uc):\n        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))\n        denoised = self.guider(denoised, sigma)\n        
return denoised\n\n    def get_sigma_gen(self, num_sigmas):\n        # one sampler step per consecutive (sigma, next_sigma) pair\n        num_steps = num_sigmas - 1\n        sigma_generator = range(num_steps)\n        if self.verbose:\n            print(\"#\" * 30, \" Sampling setting \", \"#\" * 30)\n            print(f\"Sampler: {self.__class__.__name__}\")\n            print(f\"Discretization: {self.discretization.__class__.__name__}\")\n            print(f\"Guider: {self.guider.__class__.__name__}\")\n            sigma_generator = tqdm(\n                sigma_generator,\n                total=num_steps,\n                desc=f\"Sampling with {self.__class__.__name__} for {num_steps} steps\",\n            )\n        return sigma_generator\n\n\nclass SingleStepDiffusionSampler(BaseDiffusionSampler):\n    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):\n        raise NotImplementedError\n\n    def euler_step(self, x, d, dt):\n        return x + dt * d\n\n\nclass EDMSampler(SingleStepDiffusionSampler):\n    def __init__(\n        self, s_churn=0.0, s_tmin=0.0, s_tmax=float(\"inf\"), s_noise=1.0, *args, **kwargs\n    ):\n        super().__init__(*args, **kwargs)\n\n        self.s_churn = s_churn\n        self.s_tmin = s_tmin\n        self.s_tmax = s_tmax\n        self.s_noise = s_noise\n\n    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):\n        sigma_hat = sigma * (gamma + 1.0)\n        if gamma > 0:\n            eps = torch.randn_like(x) * self.s_noise\n            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5\n\n        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)\n        d = to_d(x, sigma_hat, denoised)\n        dt = append_dims(next_sigma - sigma_hat, x.ndim)\n\n        euler_step = self.euler_step(x, d, dt)\n        x = self.possible_correction_step(\n            euler_step, x, d, dt, next_sigma, denoiser, cond, uc\n        )\n        return x\n\n    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):\n        x, s_in, sigmas, num_sigmas, 
cond, uc = self.prepare_sampling_loop(\n            x, cond, uc, num_steps\n        )\n\n        for i in self.get_sigma_gen(num_sigmas):\n            gamma = (\n                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)\n                if self.s_tmin <= sigmas[i] <= self.s_tmax\n                else 0.0\n            )\n            x = self.sampler_step(\n                s_in * sigmas[i],\n                s_in * sigmas[i + 1],\n                denoiser,\n                x,\n                cond,\n                uc,\n                gamma,\n            )\n\n        return x\n\n\nclass AncestralSampler(SingleStepDiffusionSampler):\n    def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self.eta = eta\n        self.s_noise = s_noise\n        self.noise_sampler = lambda x: torch.randn_like(x)\n\n    def ancestral_euler_step(self, x, denoised, sigma, sigma_down):\n        d = to_d(x, sigma, denoised)\n        dt = append_dims(sigma_down - sigma, x.ndim)\n\n        return self.euler_step(x, d, dt)\n\n    def ancestral_step(self, x, sigma, next_sigma, sigma_up):\n        x = torch.where(\n            append_dims(next_sigma, x.ndim) > 0.0,\n            x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),\n            x,\n        )\n        return x\n\n    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):\n        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(\n            x, cond, uc, num_steps\n        )\n\n        for i in self.get_sigma_gen(num_sigmas):\n            x = self.sampler_step(\n                s_in * sigmas[i],\n                s_in * sigmas[i + 1],\n                denoiser,\n                x,\n                cond,\n                uc,\n            )\n\n        return x\n\n\nclass LinearMultistepSampler(BaseDiffusionSampler):\n    def __init__(\n        self,\n        order=4,\n        *args,\n        **kwargs,\n    ):\n        
super().__init__(*args, **kwargs)\n\n        self.order = order\n\n    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):\n        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(\n            x, cond, uc, num_steps\n        )\n\n        ds = []\n        sigmas_cpu = sigmas.detach().cpu().numpy()\n        for i in self.get_sigma_gen(num_sigmas):\n            sigma = s_in * sigmas[i]\n            denoised = denoiser(\n                *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs\n            )\n            denoised = self.guider(denoised, sigma)\n            d = to_d(x, sigma, denoised)\n            ds.append(d)\n            if len(ds) > self.order:\n                ds.pop(0)\n            cur_order = min(i + 1, self.order)\n            coeffs = [\n                linear_multistep_coeff(cur_order, sigmas_cpu, i, j)\n                for j in range(cur_order)\n            ]\n            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))\n\n        return x\n\n\nclass EulerEDMSampler(EDMSampler):\n    def possible_correction_step(\n        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc\n    ):\n        return euler_step\n\n\nclass HeunEDMSampler(EDMSampler):\n    def possible_correction_step(\n        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc\n    ):\n        if torch.sum(next_sigma) < 1e-14:\n            # Save a network evaluation if all noise levels are 0\n            return euler_step\n        else:\n            denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)\n            d_new = to_d(euler_step, next_sigma, denoised)\n            d_prime = (d + d_new) / 2.0\n\n            # apply correction if noise level is not 0\n            x = torch.where(\n                append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step\n            )\n            return x\n\n\nclass EulerAncestralSampler(AncestralSampler):\n    def sampler_step(self, sigma, 
next_sigma, denoiser, x, cond, uc):\n        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)\n        denoised = self.denoise(x, denoiser, sigma, cond, uc)\n        x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)\n        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)\n\n        return x\n\n\nclass DPMPP2SAncestralSampler(AncestralSampler):\n    def get_variables(self, sigma, sigma_down):\n        t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]\n        h = t_next - t\n        s = t + 0.5 * h\n        return h, s, t, t_next\n\n    def get_mult(self, h, s, t, t_next):\n        mult1 = to_sigma(s) / to_sigma(t)\n        mult2 = (-0.5 * h).expm1()\n        mult3 = to_sigma(t_next) / to_sigma(t)\n        mult4 = (-h).expm1()\n\n        return mult1, mult2, mult3, mult4\n\n    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):\n        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)\n        denoised = self.denoise(x, denoiser, sigma, cond, uc)\n        x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)\n\n        if torch.sum(sigma_down) < 1e-14:\n            # Save a network evaluation if all noise levels are 0\n            x = x_euler\n        else:\n            h, s, t, t_next = self.get_variables(sigma, sigma_down)\n            mult = [\n                append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)\n            ]\n\n            x2 = mult[0] * x - mult[1] * denoised\n            denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)\n            x_dpmpp2s = mult[2] * x - mult[3] * denoised2\n\n            # apply correction if noise level is not 0\n            x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)\n\n        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)\n        return x\n\n\nclass DPMPP2MSampler(BaseDiffusionSampler):\n    def get_variables(self, sigma, 
next_sigma, previous_sigma=None):\n        t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]\n        h = t_next - t\n\n        if previous_sigma is not None:\n            h_last = t - to_neg_log_sigma(previous_sigma)\n            r = h_last / h\n            return h, r, t, t_next\n        else:\n            return h, None, t, t_next\n\n    def get_mult(self, h, r, t, t_next, previous_sigma):\n        mult1 = to_sigma(t_next) / to_sigma(t)\n        mult2 = (-h).expm1()\n\n        if previous_sigma is not None:\n            mult3 = 1 + 1 / (2 * r)\n            mult4 = 1 / (2 * r)\n            return mult1, mult2, mult3, mult4\n        else:\n            return mult1, mult2\n\n    def sampler_step(\n        self,\n        old_denoised,\n        previous_sigma,\n        sigma,\n        next_sigma,\n        denoiser,\n        x,\n        cond,\n        uc=None,\n    ):\n        denoised = self.denoise(x, denoiser, sigma, cond, uc)\n\n        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)\n        mult = [\n            append_dims(mult, x.ndim)\n            for mult in self.get_mult(h, r, t, t_next, previous_sigma)\n        ]\n\n        x_standard = mult[0] * x - mult[1] * denoised\n        if old_denoised is None or torch.sum(next_sigma) < 1e-14:\n            # Save a network evaluation if all noise levels are 0 or on the first step\n            return x_standard, denoised\n        else:\n            denoised_d = mult[2] * denoised - mult[3] * old_denoised\n            x_advanced = mult[0] * x - mult[1] * denoised_d\n\n            # apply correction if noise level is not 0 and not first step\n            x = torch.where(\n                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard\n            )\n\n        return x, denoised\n\n    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):\n        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(\n            x, cond, uc, 
num_steps\n        )\n\n        old_denoised = None\n        for i in self.get_sigma_gen(num_sigmas):\n            x, old_denoised = self.sampler_step(\n                old_denoised,\n                None if i == 0 else s_in * sigmas[i - 1],\n                s_in * sigmas[i],\n                s_in * sigmas[i + 1],\n                denoiser,\n                x,\n                cond,\n                uc=uc,\n            )\n\n        return x\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/sampling_utils.py",
    "content": "import torch\nfrom scipy import integrate\n\nfrom ...util import append_dims\n\n\ndef linear_multistep_coeff(order, t, i, j, epsrel=1e-4):\n    if order - 1 > i:\n        raise ValueError(f\"Order {order} too high for step {i}\")\n\n    def fn(tau):\n        prod = 1.0\n        for k in range(order):\n            if j == k:\n                continue\n            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])\n        return prod\n\n    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]\n\n\ndef get_ancestral_step(sigma_from, sigma_to, eta=1.0):\n    if not eta:\n        # deterministic step: no noise is re-injected, so sigma_up is zero.\n        # Return a zero tensor (not a float) so callers can treat it like sigma_to.\n        return sigma_to, torch.zeros_like(sigma_to)\n    sigma_up = torch.minimum(\n        sigma_to,\n        eta\n        * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,\n    )\n    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5\n    return sigma_down, sigma_up\n\n\ndef to_d(x, sigma, denoised):\n    return (x - denoised) / append_dims(sigma, x.ndim)\n\n\ndef to_neg_log_sigma(sigma):\n    return sigma.log().neg()\n\n\ndef to_sigma(neg_log_sigma):\n    return neg_log_sigma.neg().exp()\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/sigma_sampling.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom ...util import default, instantiate_from_config\n\n\nclass EDMSampling:\n    def __init__(self, p_mean=-1.2, p_std=1.2):\n        self.p_mean = p_mean\n        self.p_std = p_std\n\n    def __call__(self, n_samples, rand=None):\n        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))\n        return log_sigma.exp()\n\n\nclass DiscreteSampling:\n    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):\n        self.num_idx = num_idx\n        self.sigmas = instantiate_from_config(discretization_config)(\n            num_idx, do_append_zero=do_append_zero, flip=flip\n        )\n\n    def idx_to_sigma(self, idx):\n        return self.sigmas[idx]\n\n    def __call__(self, n_samples, rand=None):\n        idx = default(\n            rand,\n            torch.randint(0, self.num_idx, (n_samples,)),\n        )\n        return self.idx_to_sigma(idx)\n\n\nclass ZeroSampler:\n    def __call__(\n        self, n_samples: int, rand: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/util.py",
    "content": "\"\"\"\npartially adopted from\nhttps://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\nand\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\nand\nhttps://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py\n\nthanks!\n\"\"\"\n\nimport math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\n\n\ndef get_alpha(\n    merge_strategy: str,\n    mix_factor: Optional[torch.Tensor],\n    image_only_indicator: torch.Tensor,\n    apply_sigmoid: bool = True,\n    is_attn: bool = False,\n) -> torch.Tensor:\n    if merge_strategy == \"fixed\" or merge_strategy == \"learned\":\n        alpha = mix_factor\n    elif merge_strategy == \"learned_with_images\":\n        alpha = torch.where(\n            image_only_indicator.bool(),\n            torch.ones(1, 1, device=image_only_indicator.device),\n            rearrange(mix_factor, \"... -> ... 
1\"),\n        )\n        if is_attn:\n            alpha = rearrange(alpha, \"b t -> (b t) 1 1\")\n        else:\n            alpha = rearrange(alpha, \"b t -> b 1 t 1 1\")\n    elif merge_strategy == \"fixed_with_images\":\n        alpha = image_only_indicator\n        if is_attn:\n            alpha = rearrange(alpha, \"b t -> (b t) 1 1\")\n        else:\n            alpha = rearrange(alpha, \"b t -> b 1 t 1 1\")\n    else:\n        raise NotImplementedError\n    return torch.sigmoid(alpha) if apply_sigmoid else alpha\n\n\ndef make_beta_schedule(\n    schedule,\n    n_timestep,\n    linear_start=1e-4,\n    linear_end=2e-2,\n):\n    if schedule == \"linear\":\n        betas = (\n            torch.linspace(\n                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64\n            )\n            ** 2\n        )\n    else:\n        raise NotImplementedError(f\"unknown beta schedule: {schedule}\")\n    return betas.numpy()\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef mixed_checkpoint(func, inputs: dict, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass. 
This differs from the original checkpoint function\n    borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that\n    it also works with non-tensor inputs\n    :param func: the function to evaluate.\n    :param inputs: the argument dictionary to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if flag:\n        tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]\n        tensor_inputs = [\n            inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)\n        ]\n        non_tensor_keys = [\n            key for key in inputs if not isinstance(inputs[key], torch.Tensor)\n        ]\n        non_tensor_inputs = [\n            inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)\n        ]\n        args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)\n        return MixedCheckpointFunction.apply(\n            func,\n            len(tensor_inputs),\n            len(non_tensor_inputs),\n            tensor_keys,\n            non_tensor_keys,\n            *args,\n        )\n    else:\n        return func(**inputs)\n\n\nclass MixedCheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        run_function,\n        length_tensors,\n        length_non_tensors,\n        tensor_keys,\n        non_tensor_keys,\n        *args,\n    ):\n        ctx.end_tensors = length_tensors\n        ctx.end_non_tensors = length_tensors + length_non_tensors\n        ctx.gpu_autocast_kwargs = {\n            \"enabled\": torch.is_autocast_enabled(),\n            \"dtype\": torch.get_autocast_gpu_dtype(),\n            \"cache_enabled\": torch.is_autocast_cache_enabled(),\n        }\n        assert (\n            len(tensor_keys) 
== length_tensors\n            and len(non_tensor_keys) == length_non_tensors\n        )\n\n        ctx.input_tensors = {\n            key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))\n        }\n        ctx.input_non_tensors = {\n            key: val\n            for (key, val) in zip(\n                non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])\n            )\n        }\n        ctx.run_function = run_function\n        ctx.input_params = list(args[ctx.end_non_tensors :])\n\n        with torch.no_grad():\n            output_tensors = ctx.run_function(\n                **ctx.input_tensors, **ctx.input_non_tensors\n            )\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}\n        ctx.input_tensors = {\n            key: ctx.input_tensors[key].detach().requires_grad_(True)\n            for key in ctx.input_tensors\n        }\n\n        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = {\n                key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])\n                for key in ctx.input_tensors\n            }\n            # shallow_copies.update(additional_args)\n            output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            list(ctx.input_tensors.values()) + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (\n            (None, None, None, 
None, None)\n            + input_grads[: ctx.end_tensors]\n            + (None,) * (ctx.end_non_tensors - ctx.end_tensors)\n            + input_grads[ctx.end_tensors :]\n        )\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if flag:\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n        ctx.gpu_autocast_kwargs = {\n            \"enabled\": torch.is_autocast_enabled(),\n            \"dtype\": torch.get_autocast_gpu_dtype(),\n            \"cache_enabled\": torch.is_autocast_cache_enabled(),\n        }\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]\n        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = 
ctx.run_function(*shallow_copies)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period)\n            * torch.arange(start=0, end=half, dtype=torch.float32)\n            / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat(\n                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1\n            )\n    else:\n        embedding = repeat(timesteps, \"b -> b d\", d=dim)\n    return embedding\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef 
normalization(channels):\n    \"\"\"\n    Make a standard normalization layer.\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(32, channels)\n\n\n# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.\nclass SiLU(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\nclass AlphaBlender(nn.Module):\n    strategies = [\"learned\", \"fixed\", \"learned_with_images\"]\n\n    def __init__(\n        self,\n        alpha: float,\n        merge_strategy: str = \"learned_with_images\",\n        rearrange_pattern: str = \"b t -> (b t) 1 1\",\n    ):\n        super().__init__()\n        self.merge_strategy = merge_strategy\n        self.rearrange_pattern = rearrange_pattern\n\n        assert (\n            merge_strategy in self.strategies\n        ), f\"merge_strategy needs to be in {self.strategies}\"\n\n        if self.merge_strategy == \"fixed\":\n            self.register_buffer(\"mix_factor\", 
torch.Tensor([alpha]))\n        elif (\n            self.merge_strategy == \"learned\"\n            or self.merge_strategy == \"learned_with_images\"\n        ):\n            self.register_parameter(\n                \"mix_factor\", torch.nn.Parameter(torch.Tensor([alpha]))\n            )\n        else:\n            raise ValueError(f\"unknown merge strategy {self.merge_strategy}\")\n\n    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:\n        if self.merge_strategy == \"fixed\":\n            alpha = self.mix_factor\n        elif self.merge_strategy == \"learned\":\n            alpha = torch.sigmoid(self.mix_factor)\n        elif self.merge_strategy == \"learned_with_images\":\n            assert image_only_indicator is not None, \"need image_only_indicator ...\"\n            alpha = torch.where(\n                image_only_indicator.bool(),\n                torch.ones(1, 1, device=image_only_indicator.device),\n                rearrange(torch.sigmoid(self.mix_factor), \"... -> ... 1\"),\n            )\n            alpha = rearrange(alpha, self.rearrange_pattern)\n        else:\n            raise NotImplementedError\n        return alpha\n\n    def forward(\n        self,\n        x_spatial: torch.Tensor,\n        x_temporal: torch.Tensor,\n        image_only_indicator: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        alpha = self.get_alpha(image_only_indicator)\n        x = (\n            alpha.to(x_spatial.dtype) * x_spatial\n            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal\n        )\n        return x\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/video_model.py",
    "content": "from functools import partial\nfrom typing import List, Optional, Union\n\nfrom einops import rearrange\n\nfrom ...modules.diffusionmodules.openaimodel import *\nfrom ...modules.video_attention import SpatialVideoTransformer\nfrom ...modules.spacetime_attention import (\n    BasicTransformerTimeMixBlock,\n    PostHocSpatialTransformerWithTimeMixing,\n    PostHocSpatialTransformerWithTimeMixingAndMotion,\n)\nfrom ...util import default\nfrom .util import AlphaBlender, get_alpha\n\n\nclass VideoResBlock(ResBlock):\n    def __init__(\n        self,\n        channels: int,\n        emb_channels: int,\n        dropout: float,\n        video_kernel_size: Union[int, List[int]] = 3,\n        merge_strategy: str = \"fixed\",\n        merge_factor: float = 0.5,\n        out_channels: Optional[int] = None,\n        use_conv: bool = False,\n        use_scale_shift_norm: bool = False,\n        dims: int = 2,\n        use_checkpoint: bool = False,\n        up: bool = False,\n        down: bool = False,\n    ):\n        super().__init__(\n            channels,\n            emb_channels,\n            dropout,\n            out_channels=out_channels,\n            use_conv=use_conv,\n            use_scale_shift_norm=use_scale_shift_norm,\n            dims=dims,\n            use_checkpoint=use_checkpoint,\n            up=up,\n            down=down,\n        )\n\n        self.time_stack = ResBlock(\n            default(out_channels, channels),\n            emb_channels,\n            dropout=dropout,\n            dims=3,\n            out_channels=default(out_channels, channels),\n            use_scale_shift_norm=False,\n            use_conv=False,\n            up=False,\n            down=False,\n            kernel_size=video_kernel_size,\n            use_checkpoint=use_checkpoint,\n            exchange_temb_dims=True,\n        )\n        self.time_mixer = AlphaBlender(\n            alpha=merge_factor,\n            merge_strategy=merge_strategy,\n            
rearrange_pattern=\"b t -> b 1 t 1 1\",\n        )\n\n    def forward(\n        self,\n        x: th.Tensor,\n        emb: th.Tensor,\n        num_video_frames: int,\n        image_only_indicator: Optional[th.Tensor] = None,\n    ) -> th.Tensor:\n        x = super().forward(x, emb)\n\n        x_mix = rearrange(x, \"(b t) c h w -> b c t h w\", t=num_video_frames)\n        x = rearrange(x, \"(b t) c h w -> b c t h w\", t=num_video_frames)\n\n        x = self.time_stack(\n            x, rearrange(emb, \"(b t) ... -> b t ...\", t=num_video_frames)\n        )\n        x = self.time_mixer(\n            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator\n        )\n        x = rearrange(x, \"b c t h w -> (b t) c h w\")\n        return x\n\n\nclass VideoUNet(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        model_channels: int,\n        out_channels: int,\n        num_res_blocks: int,\n        attention_resolutions: int,\n        dropout: float = 0.0,\n        channel_mult: List[int] = (1, 2, 4, 8),\n        conv_resample: bool = True,\n        dims: int = 2,\n        num_classes: Optional[int] = None,\n        use_checkpoint: bool = False,\n        num_heads: int = -1,\n        num_head_channels: int = -1,\n        num_heads_upsample: int = -1,\n        use_scale_shift_norm: bool = False,\n        resblock_updown: bool = False,\n        transformer_depth: Union[List[int], int] = 1,\n        transformer_depth_middle: Optional[int] = None,\n        context_dim: Optional[int] = None,\n        time_downup: bool = False,\n        time_context_dim: Optional[int] = None,\n        extra_ff_mix_layer: bool = False,\n        use_spatial_context: bool = False,\n        merge_strategy: str = \"fixed\",\n        merge_factor: float = 0.5,\n        spatial_transformer_attn_type: str = \"softmax\",\n        video_kernel_size: Union[int, List[int]] = 3,\n        use_linear_in_transformer: bool = False,\n        adm_in_channels: 
Optional[int] = None,\n        disable_temporal_crossattention: bool = False,\n        max_ddpm_temb_period: int = 10000,\n    ):\n        super().__init__()\n        assert context_dim is not None\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert num_head_channels != -1\n\n        if num_head_channels == -1:\n            assert num_heads != -1\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        if isinstance(transformer_depth, int):\n            transformer_depth = len(channel_mult) * [transformer_depth]\n        transformer_depth_middle = default(\n            transformer_depth_middle, transformer_depth[-1]\n        )\n\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            if isinstance(self.num_classes, int):\n                self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n            elif self.num_classes == \"continuous\":\n                print(\"setting up linear c_adm embedding layer\")\n                self.label_emb = nn.Linear(1, time_embed_dim)\n            elif self.num_classes == \"timestep\":\n                self.label_emb = nn.Sequential(\n                    Timestep(model_channels),\n           
         nn.Sequential(\n                        linear(model_channels, time_embed_dim),\n                        nn.SiLU(),\n                        linear(time_embed_dim, time_embed_dim),\n                    ),\n                )\n\n            elif self.num_classes == \"sequential\":\n                assert adm_in_channels is not None\n                self.label_emb = nn.Sequential(\n                    nn.Sequential(\n                        linear(adm_in_channels, time_embed_dim),\n                        nn.SiLU(),\n                        linear(time_embed_dim, time_embed_dim),\n                    )\n                )\n            else:\n                raise ValueError()\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n\n        def get_attention_layer(\n            ch,\n            num_heads,\n            dim_head,\n            depth=1,\n            context_dim=None,\n            use_checkpoint=False,\n            disabled_sa=False,\n        ):\n            return SpatialVideoTransformer(\n                ch,\n                num_heads,\n                dim_head,\n                depth=depth,\n                context_dim=context_dim,\n                time_context_dim=time_context_dim,\n                dropout=dropout,\n                ff_in=extra_ff_mix_layer,\n                use_spatial_context=use_spatial_context,\n                merge_strategy=merge_strategy,\n                merge_factor=merge_factor,\n                checkpoint=use_checkpoint,\n                use_linear=use_linear_in_transformer,\n                attn_mode=spatial_transformer_attn_type,\n                disable_self_attn=disabled_sa,\n                
disable_temporal_crossattention=disable_temporal_crossattention,\n                max_time_embed_period=max_ddpm_temb_period,\n            )\n\n        def get_resblock(\n            merge_factor,\n            merge_strategy,\n            video_kernel_size,\n            ch,\n            time_embed_dim,\n            dropout,\n            out_ch,\n            dims,\n            use_checkpoint,\n            use_scale_shift_norm,\n            down=False,\n            up=False,\n        ):\n            return VideoResBlock(\n                merge_factor=merge_factor,\n                merge_strategy=merge_strategy,\n                video_kernel_size=video_kernel_size,\n                channels=ch,\n                emb_channels=time_embed_dim,\n                dropout=dropout,\n                out_channels=out_ch,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n                down=down,\n                up=up,\n            )\n\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    get_resblock(\n                        merge_factor=merge_factor,\n                        merge_strategy=merge_strategy,\n                        video_kernel_size=video_kernel_size,\n                        ch=ch,\n                        time_embed_dim=time_embed_dim,\n                        dropout=dropout,\n                        out_ch=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = 
ch // num_head_channels\n                        dim_head = num_head_channels\n\n                    layers.append(\n                        get_attention_layer(\n                            ch,\n                            num_heads,\n                            dim_head,\n                            depth=transformer_depth[level],\n                            context_dim=context_dim,\n                            use_checkpoint=use_checkpoint,\n                            disabled_sa=False,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                ds *= 2\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        get_resblock(\n                            merge_factor=merge_factor,\n                            merge_strategy=merge_strategy,\n                            video_kernel_size=video_kernel_size,\n                            ch=ch,\n                            time_embed_dim=time_embed_dim,\n                            dropout=dropout,\n                            out_ch=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch,\n                            conv_resample,\n                            dims=dims,\n                            out_channels=out_ch,\n                            third_down=time_downup,\n                        )\n                    )\n                )\n                ch = out_ch\n                
input_block_chans.append(ch)\n\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n\n        self.middle_block = TimestepEmbedSequential(\n            get_resblock(\n                merge_factor=merge_factor,\n                merge_strategy=merge_strategy,\n                video_kernel_size=video_kernel_size,\n                ch=ch,\n                time_embed_dim=time_embed_dim,\n                out_ch=None,\n                dropout=dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            get_attention_layer(\n                ch,\n                num_heads,\n                dim_head,\n                depth=transformer_depth_middle,\n                context_dim=context_dim,\n                use_checkpoint=use_checkpoint,\n            ),\n            get_resblock(\n                merge_factor=merge_factor,\n                merge_strategy=merge_strategy,\n                video_kernel_size=video_kernel_size,\n                ch=ch,\n                out_ch=None,\n                time_embed_dim=time_embed_dim,\n                dropout=dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(num_res_blocks + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    get_resblock(\n                        merge_factor=merge_factor,\n                        merge_strategy=merge_strategy,\n                        video_kernel_size=video_kernel_size,\n         
               ch=ch + ich,\n                        time_embed_dim=time_embed_dim,\n                        dropout=dropout,\n                        out_ch=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n\n                    layers.append(\n                        get_attention_layer(\n                            ch,\n                            num_heads,\n                            dim_head,\n                            depth=transformer_depth[level],\n                            context_dim=context_dim,\n                            use_checkpoint=use_checkpoint,\n                            disabled_sa=False,\n                        )\n                    )\n                if level and i == num_res_blocks:\n                    out_ch = ch\n                    ds //= 2\n                    layers.append(\n                        get_resblock(\n                            merge_factor=merge_factor,\n                            merge_strategy=merge_strategy,\n                            video_kernel_size=video_kernel_size,\n                            ch=ch,\n                            time_embed_dim=time_embed_dim,\n                            dropout=dropout,\n                            out_ch=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                 
       if resblock_updown\n                        else Upsample(\n                            ch,\n                            conv_resample,\n                            dims=dims,\n                            out_channels=out_ch,\n                            third_up=time_downup,\n                        )\n                    )\n\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),\n        )\n\n    def forward(\n        self,\n        x: th.Tensor,\n        timesteps: th.Tensor,\n        context: Optional[th.Tensor] = None,\n        y: Optional[th.Tensor] = None,\n        time_context: Optional[th.Tensor] = None,\n        num_video_frames: Optional[int] = None,\n        image_only_indicator: Optional[th.Tensor] = None,\n    ):\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional -> no, relax this TODO\"\n        hs = []\n        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)\n        emb = self.time_embed(t_emb)\n\n        if self.num_classes is not None:\n            assert y.shape[0] == x.shape[0]\n            emb = emb + self.label_emb(y)\n\n        h = x\n        for module in self.input_blocks:\n            h = module(\n                h,\n                emb,\n                context=context,\n                image_only_indicator=image_only_indicator,\n                time_context=time_context,\n                num_video_frames=num_video_frames,\n            )\n            hs.append(h)\n        h = self.middle_block(\n            h,\n            emb,\n            context=context,\n            image_only_indicator=image_only_indicator,\n            time_context=time_context,\n           
 num_video_frames=num_video_frames,\n        )\n        for module in self.output_blocks:\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(\n                h,\n                emb,\n                context=context,\n                image_only_indicator=image_only_indicator,\n                time_context=time_context,\n                num_video_frames=num_video_frames,\n            )\n        h = h.type(x.dtype)\n        return self.out(h)\n\n\nclass PostHocAttentionBlockWithTimeMixing(AttentionBlock):\n    def __init__(\n        self,\n        in_channels: int,\n        n_heads: int,\n        d_head: int,\n        use_checkpoint: bool = False,\n        use_new_attention_order: bool = False,\n        dropout: float = 0.0,\n        use_spatial_context: bool = False,\n        merge_strategy: str = \"fixed\",\n        merge_factor: float = 0.5,\n        apply_sigmoid_to_merge: bool = True,\n        ff_in: bool = False,\n        attn_mode: str = \"softmax\",\n        disable_temporal_crossattention: bool = False,\n    ):\n        super().__init__(\n            in_channels,\n            n_heads,\n            d_head,\n            use_checkpoint=use_checkpoint,\n            use_new_attention_order=use_new_attention_order,\n        )\n        inner_dim = n_heads * d_head\n\n        self.time_mix_blocks = nn.ModuleList(\n            [\n                BasicTransformerTimeMixBlock(\n                    inner_dim,\n                    n_heads,\n                    d_head,\n                    dropout=dropout,\n                    checkpoint=use_checkpoint,\n                    ff_in=ff_in,\n                    attn_mode=attn_mode,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                )\n            ]\n        )\n        self.in_channels = in_channels\n\n        time_embed_dim = self.in_channels * 4\n        self.time_mix_time_embed = nn.Sequential(\n            linear(self.in_channels, time_embed_dim),\n   
         nn.SiLU(),\n            linear(time_embed_dim, self.in_channels),\n        )\n\n        self.use_spatial_context = use_spatial_context\n\n        if merge_strategy == \"fixed\":\n            self.register_buffer(\"mix_factor\", th.Tensor([merge_factor]))\n        elif merge_strategy == \"learned\" or merge_strategy == \"learned_with_images\":\n            self.register_parameter(\n                \"mix_factor\", th.nn.Parameter(th.Tensor([merge_factor]))\n            )\n        elif merge_strategy == \"fixed_with_images\":\n            self.mix_factor = None\n        else:\n            raise ValueError(f\"unknown merge strategy {merge_strategy}\")\n\n        self.get_alpha_fn = functools.partial(\n            get_alpha,\n            merge_strategy,\n            self.mix_factor,\n            apply_sigmoid=apply_sigmoid_to_merge,\n        )\n\n    def forward(\n        self,\n        x: th.Tensor,\n        context: Optional[th.Tensor] = None,\n        # cam: Optional[th.Tensor] = None,\n        time_context: Optional[th.Tensor] = None,\n        timesteps: Optional[int] = None,\n        image_only_indicator: Optional[th.Tensor] = None,\n        conv_view: Optional[th.Tensor] = None,\n        conv_motion: Optional[th.Tensor] = None,\n    ):\n        if time_context is not None:\n            raise NotImplementedError\n\n        _, _, h, w = x.shape\n        if exists(context):\n            context = rearrange(context, \"b t ... -> (b t) ...\")\n        if self.use_spatial_context:\n            time_context = repeat(context[:, 0], \"b ... 
-> (b n) ...\", n=h * w)\n\n        x = super().forward(\n            x,\n        )\n\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n        x_mix = x\n\n        num_frames = th.arange(timesteps, device=x.device)\n        num_frames = repeat(num_frames, \"t -> b t\", b=x.shape[0] // timesteps)\n        num_frames = rearrange(num_frames, \"b t -> (b t)\")\n        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)\n        emb = self.time_mix_time_embed(t_emb)\n        emb = emb[:, None, :]\n        x_mix = x_mix + emb\n\n        x_mix = self.time_mix_blocks[0](\n            x_mix, context=time_context, timesteps=timesteps\n        )\n\n        alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)\n        x = alpha * x + (1.0 - alpha) * x_mix\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        return x\n\n\nclass PostHocResBlockWithTime(ResBlock):\n    def __init__(\n        self,\n        channels: int,\n        emb_channels: int,\n        dropout: float,\n        time_kernel_size: Union[int, List[int]] = 3,\n        merge_strategy: bool = \"fixed\",\n        merge_factor: float = 0.5,\n        apply_sigmoid_to_merge: bool = True,\n        out_channels: Optional[int] = None,\n        use_conv: bool = False,\n        use_scale_shift_norm: bool = False,\n        dims: int = 2,\n        use_checkpoint: bool = False,\n        up: bool = False,\n        down: bool = False,\n        time_mix_legacy: bool = True,\n        replicate_bug: bool = False,\n    ):\n        super().__init__(\n            channels,\n            emb_channels,\n            dropout,\n            out_channels=out_channels,\n            use_conv=use_conv,\n            use_scale_shift_norm=use_scale_shift_norm,\n            dims=dims,\n            use_checkpoint=use_checkpoint,\n            up=up,\n            down=down,\n        )\n\n        self.time_mix_blocks = ResBlock(\n            default(out_channels, channels),\n            
emb_channels,\n            dropout=dropout,\n            dims=3,\n            out_channels=default(out_channels, channels),\n            use_scale_shift_norm=False,\n            use_conv=False,\n            up=False,\n            down=False,\n            kernel_size=time_kernel_size,\n            use_checkpoint=use_checkpoint,\n            exchange_temb_dims=True,\n        )\n        self.time_mix_legacy = time_mix_legacy\n        if self.time_mix_legacy:\n            if merge_strategy == \"fixed\":\n                self.register_buffer(\"mix_factor\", th.Tensor([merge_factor]))\n            elif merge_strategy == \"learned\" or merge_strategy == \"learned_with_images\":\n                self.register_parameter(\n                    \"mix_factor\", th.nn.Parameter(th.Tensor([merge_factor]))\n                )\n            elif merge_strategy == \"fixed_with_images\":\n                self.mix_factor = None\n            else:\n                raise ValueError(f\"unknown merge strategy {merge_strategy}\")\n\n            self.get_alpha_fn = functools.partial(\n                get_alpha,\n                merge_strategy,\n                self.mix_factor,\n                apply_sigmoid=apply_sigmoid_to_merge,\n            )\n        else:\n            if False: # replicate_bug:\n                logpy.warning(\n                    \"*****************************************************************************************\\n\"\n                    \"GRAVE WARNING: YOU'RE USING THE BUGGY LEGACY ALPHABLENDER!!! 
ARE YOU SURE YOU WANT THIS?!\\n\"\n                    \"*****************************************************************************************\"\n                )\n                self.time_mixer = LegacyAlphaBlenderWithBug(\n                    alpha=merge_factor,\n                    merge_strategy=merge_strategy,\n                    rearrange_pattern=\"b t -> b 1 t 1 1\",\n                )\n            else:\n                self.time_mixer = AlphaBlender(\n                    alpha=merge_factor,\n                    merge_strategy=merge_strategy,\n                    rearrange_pattern=\"b t -> b 1 t 1 1\",\n                )\n\n    def forward(\n        self,\n        x: th.Tensor,\n        emb: th.Tensor,\n        num_video_frames: int,\n        image_only_indicator: Optional[th.Tensor] = None,\n        cond_view: Optional[th.Tensor] = None,\n        cond_motion: Optional[th.Tensor] = None,\n    ) -> th.Tensor:\n        x = super().forward(x, emb)\n\n        x_mix = rearrange(x, \"(b t) c h w -> b c t h w\", t=num_video_frames)\n        x = rearrange(x, \"(b t) c h w -> b c t h w\", t=num_video_frames)\n\n        x = self.time_mix_blocks(\n            x, rearrange(emb, \"(b t) ... 
-> b t ...\", t=num_video_frames)\n        )\n\n        if self.time_mix_legacy:\n            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator*0.0)\n            x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix\n        else:\n            x = self.time_mixer(\n                x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator*0.0\n            )\n        x = rearrange(x, \"b c t h w -> (b t) c h w\")\n        return x\n\n\nclass SpatialUNetModelWithTime(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        model_channels: int,\n        out_channels: int,\n        num_res_blocks: int,\n        attention_resolutions: int,\n        dropout: float = 0.0,\n        channel_mult: List[int] = (1, 2, 4, 8),\n        conv_resample: bool = True,\n        dims: int = 2,\n        num_classes: Optional[int] = None,\n        use_checkpoint: bool = False,\n        num_heads: int = -1,\n        num_head_channels: int = -1,\n        num_heads_upsample: int = -1,\n        use_scale_shift_norm: bool = False,\n        resblock_updown: bool = False,\n        use_new_attention_order: bool = False,\n        use_spatial_transformer: bool = False,\n        transformer_depth: Union[List[int], int] = 1,\n        transformer_depth_middle: Optional[int] = None,\n        context_dim: Optional[int] = None,\n        time_downup: bool = False,\n        time_context_dim: Optional[int] = None,\n        view_context_dim: Optional[int] = None,\n        motion_context_dim: Optional[int] = None,\n        extra_ff_mix_layer: bool = False,\n        use_spatial_context: bool = False,\n        time_block_merge_strategy: str = \"fixed\",\n        time_block_merge_factor: float = 0.5,\n        view_block_merge_factor: float = 0.5,\n        motion_block_merge_factor: float = 0.5,\n        spatial_transformer_attn_type: str = \"softmax\",\n        time_kernel_size: Union[int, List[int]] = 3,\n        use_linear_in_transformer: bool = 
False,\n        legacy: bool = True,\n        adm_in_channels: Optional[int] = None,\n        use_temporal_resblock: bool = True,\n        disable_temporal_crossattention: bool = False,\n        time_mix_legacy: bool = True,\n        max_ddpm_temb_period: int = 10000,\n        replicate_time_mix_bug: bool = False,\n        use_motion_attention: bool = False,\n        use_camera_emb: bool = False,\n        use_3d_attention: bool = False,\n        separate_motion_merge_factor: bool = False,\n    ):\n        super().__init__()\n\n        if use_spatial_transformer:\n            assert context_dim is not None\n\n        if context_dim is not None:\n            assert use_spatial_transformer\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert num_head_channels != -1\n\n        if num_head_channels == -1:\n            assert num_heads != -1\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        if isinstance(transformer_depth, int):\n            transformer_depth = len(channel_mult) * [transformer_depth]\n        transformer_depth_middle = default(\n            transformer_depth_middle, transformer_depth[-1]\n        )\n\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n        self.use_temporal_resblocks = use_temporal_resblock\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            
linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            if isinstance(self.num_classes, int):\n                self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n            elif self.num_classes == \"continuous\":\n                print(\"setting up linear c_adm embedding layer\")\n                self.label_emb = nn.Linear(1, time_embed_dim)\n            elif self.num_classes == \"timestep\":\n                self.label_emb = nn.Sequential(\n                    Timestep(model_channels),\n                    nn.Sequential(\n                        linear(model_channels, time_embed_dim),\n                        nn.SiLU(),\n                        linear(time_embed_dim, time_embed_dim),\n                    ),\n                )\n\n            elif self.num_classes == \"sequential\":\n                assert adm_in_channels is not None\n                self.label_emb = nn.Sequential(\n                    nn.Sequential(\n                        linear(adm_in_channels, time_embed_dim),\n                        nn.SiLU(),\n                        linear(time_embed_dim, time_embed_dim),\n                    )\n                )\n            else:\n                raise ValueError()\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n\n        def get_attention_layer(\n            ch,\n            num_heads,\n            dim_head,\n            depth=1,\n            context_dim=None,\n            use_checkpoint=False,\n            disabled_sa=False,\n        ):\n            if not use_spatial_transformer:\n                return PostHocAttentionBlockWithTimeMixing(\n                    ch,\n                 
   num_heads,\n                    dim_head,\n                    use_checkpoint=use_checkpoint,\n                    use_new_attention_order=use_new_attention_order,\n                    dropout=dropout,\n                    ff_in=extra_ff_mix_layer,\n                    use_spatial_context=use_spatial_context,\n                    merge_strategy=time_block_merge_strategy,\n                    merge_factor=time_block_merge_factor,\n                    attn_mode=spatial_transformer_attn_type,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                )\n\n            elif use_motion_attention:\n                return PostHocSpatialTransformerWithTimeMixingAndMotion(\n                    ch,\n                    num_heads,\n                    dim_head,\n                    depth=depth,\n                    context_dim=context_dim,\n                    time_context_dim=time_context_dim,\n                    motion_context_dim=motion_context_dim,\n                    dropout=dropout,\n                    ff_in=extra_ff_mix_layer,\n                    use_spatial_context=use_spatial_context,\n                    use_camera_emb=use_camera_emb,\n                    use_3d_attention=use_3d_attention,\n                    separate_motion_merge_factor=separate_motion_merge_factor,\n                    adm_in_channels=adm_in_channels,\n                    merge_strategy=time_block_merge_strategy,\n                    merge_factor=time_block_merge_factor,\n                    merge_factor_motion=motion_block_merge_factor,\n                    checkpoint=use_checkpoint,\n                    use_linear=use_linear_in_transformer,\n                    attn_mode=spatial_transformer_attn_type,\n                    disable_self_attn=disabled_sa,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                    time_mix_legacy=time_mix_legacy,\n                    
max_time_embed_period=max_ddpm_temb_period,\n                )\n\n            else:\n                return PostHocSpatialTransformerWithTimeMixing(\n                    ch,\n                    num_heads,\n                    dim_head,\n                    depth=depth,\n                    context_dim=context_dim,\n                    time_context_dim=time_context_dim,\n                    dropout=dropout,\n                    ff_in=extra_ff_mix_layer,\n                    use_spatial_context=use_spatial_context,\n                    merge_strategy=time_block_merge_strategy,\n                    merge_factor=time_block_merge_factor,\n                    checkpoint=use_checkpoint,\n                    use_linear=use_linear_in_transformer,\n                    attn_mode=spatial_transformer_attn_type,\n                    disable_self_attn=disabled_sa,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                    time_mix_legacy=time_mix_legacy,\n                    max_time_embed_period=max_ddpm_temb_period,\n                )\n\n        def get_resblock(\n            time_block_merge_factor,\n            time_block_merge_strategy,\n            time_kernel_size,\n            ch,\n            time_embed_dim,\n            dropout,\n            out_ch,\n            dims,\n            use_checkpoint,\n            use_scale_shift_norm,\n            down=False,\n            up=False,\n        ):\n            if self.use_temporal_resblocks:\n                return PostHocResBlockWithTime(\n                    merge_factor=time_block_merge_factor,\n                    merge_strategy=time_block_merge_strategy,\n                    time_kernel_size=time_kernel_size,\n                    channels=ch,\n                    emb_channels=time_embed_dim,\n                    dropout=dropout,\n                    out_channels=out_ch,\n                    dims=dims,\n                    use_checkpoint=use_checkpoint,\n                    
use_scale_shift_norm=use_scale_shift_norm,\n                    down=down,\n                    up=up,\n                    time_mix_legacy=time_mix_legacy,\n                    replicate_bug=replicate_time_mix_bug,\n                )\n            else:\n                return ResBlock(\n                    channels=ch,\n                    emb_channels=time_embed_dim,\n                    dropout=dropout,\n                    out_channels=out_ch,\n                    use_checkpoint=use_checkpoint,\n                    dims=dims,\n                    use_scale_shift_norm=use_scale_shift_norm,\n                    down=down,\n                    up=up,\n                )\n\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    get_resblock(\n                        time_block_merge_factor=time_block_merge_factor,\n                        time_block_merge_strategy=time_block_merge_strategy,\n                        time_kernel_size=time_kernel_size,\n                        ch=ch,\n                        time_embed_dim=time_embed_dim,\n                        dropout=dropout,\n                        out_ch=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        dim_head = (\n                            ch // num_heads\n                            if use_spatial_transformer\n                            else 
num_head_channels\n                        )\n\n                    layers.append(\n                        get_attention_layer(\n                            ch,\n                            num_heads,\n                            dim_head,\n                            depth=transformer_depth[level],\n                            context_dim=context_dim,\n                            use_checkpoint=use_checkpoint,\n                            disabled_sa=False,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                ds *= 2\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        get_resblock(\n                            time_block_merge_factor=time_block_merge_factor,\n                            time_block_merge_strategy=time_block_merge_strategy,\n                            time_kernel_size=time_kernel_size,\n                            ch=ch,\n                            time_embed_dim=time_embed_dim,\n                            dropout=dropout,\n                            out_ch=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch,\n                            conv_resample,\n                            dims=dims,\n                            out_channels=out_ch,\n                            third_down=time_downup,\n                        )\n                    )\n                )\n                ch = out_ch\n                
input_block_chans.append(ch)\n\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n        if legacy:\n            # num_heads = 1\n            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n\n        self.middle_block = TimestepEmbedSequential(\n            get_resblock(\n                time_block_merge_factor=time_block_merge_factor,\n                time_block_merge_strategy=time_block_merge_strategy,\n                time_kernel_size=time_kernel_size,\n                ch=ch,\n                time_embed_dim=time_embed_dim,\n                out_ch=None,\n                dropout=dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            get_attention_layer(\n                ch,\n                num_heads,\n                dim_head,\n                depth=transformer_depth_middle,\n                context_dim=context_dim,\n                use_checkpoint=use_checkpoint,\n            ),\n            get_resblock(\n                time_block_merge_factor=time_block_merge_factor,\n                time_block_merge_strategy=time_block_merge_strategy,\n                time_kernel_size=time_kernel_size,\n                ch=ch,\n                out_ch=None,\n                time_embed_dim=time_embed_dim,\n                dropout=dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(num_res_blocks + 1):\n                ich = input_block_chans.pop()\n                
layers = [\n                    get_resblock(\n                        time_block_merge_factor=time_block_merge_factor,\n                        time_block_merge_strategy=time_block_merge_strategy,\n                        time_kernel_size=time_kernel_size,\n                        ch=ch + ich,\n                        time_embed_dim=time_embed_dim,\n                        dropout=dropout,\n                        out_ch=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        dim_head = (\n                            ch // num_heads\n                            if use_spatial_transformer\n                            else num_head_channels\n                        )\n\n                    layers.append(\n                        get_attention_layer(\n                            ch,\n                            num_heads,\n                            dim_head,\n                            depth=transformer_depth[level],\n                            context_dim=context_dim,\n                            use_checkpoint=use_checkpoint,\n                            disabled_sa=False,\n                        )\n                    )\n                if level and i == num_res_blocks:\n                    out_ch = ch\n                    ds //= 2\n                    layers.append(\n                        get_resblock(\n                            time_block_merge_factor=time_block_merge_factor,\n                          
  time_block_merge_strategy=time_block_merge_strategy,\n                            time_kernel_size=time_kernel_size,\n                            ch=ch,\n                            time_embed_dim=time_embed_dim,\n                            dropout=dropout,\n                            out_ch=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(\n                            ch,\n                            conv_resample,\n                            dims=dims,\n                            out_channels=out_ch,\n                            third_up=time_downup,\n                        )\n                    )\n\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),\n        )\n\n    def forward(\n        self,\n        x: th.Tensor,\n        timesteps: th.Tensor,\n        context: Optional[th.Tensor] = None,\n        y: Optional[th.Tensor] = None,\n        cam: Optional[th.Tensor] = None,\n        time_context: Optional[th.Tensor] = None,\n        num_video_frames: Optional[int] = None,\n        image_only_indicator: Optional[th.Tensor] = None,\n        cond_view: Optional[th.Tensor] = None,\n        cond_motion: Optional[th.Tensor] = None,\n        time_step: Optional[int] = None,\n    ):\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional -> no, relax this TODO\"\n        hs = []\n        t_emb = timestep_embedding(timesteps, self.model_channels, 
repeat_only=False) # 21 x 320\n        emb = self.time_embed(t_emb) # 21 x 1280\n        time = str(timesteps[0].data.cpu().numpy())\n\n        if self.num_classes is not None:\n            assert y.shape[0] == x.shape[0]\n            emb = emb + self.label_emb(y) # 21 x 1280\n\n        h = x # 21 x 8 x 64 x 64\n        for i, module in enumerate(self.input_blocks):\n            h = module(\n                h,\n                emb,\n                context=context,\n                cam=cam,\n                image_only_indicator=image_only_indicator,\n                cond_view=cond_view,\n                cond_motion=cond_motion,\n                time_context=time_context,\n                num_video_frames=num_video_frames,\n                time_step=time_step,\n                name='encoder_{}_{}'.format(time, i)\n            )\n            hs.append(h)\n        h = self.middle_block(\n            h,\n            emb,\n            context=context,\n            cam=cam,\n            image_only_indicator=image_only_indicator,\n            cond_view=cond_view,\n            cond_motion=cond_motion,\n            time_context=time_context,\n            num_video_frames=num_video_frames,\n            time_step=time_step,\n            name='middle_{}_0'.format(time)\n        )\n        for i, module in enumerate(self.output_blocks):\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(\n                h,\n                emb,\n                context=context,\n                cam=cam,\n                image_only_indicator=image_only_indicator,\n                cond_view=cond_view,\n                cond_motion=cond_motion,\n                time_context=time_context,\n                num_video_frames=num_video_frames,\n                time_step=time_step,\n                name='decoder_{}_{}'.format(time, i)\n            )\n        h = h.type(x.dtype)\n        return self.out(h)\n"
  },
  {
    "path": "sgm/modules/diffusionmodules/wrappers.py",
    "content": "import torch\nimport torch.nn as nn\nfrom packaging import version\n\nOPENAIUNETWRAPPER = \"sgm.modules.diffusionmodules.wrappers.OpenAIWrapper\"\n\n\nclass IdentityWrapper(nn.Module):\n    def __init__(self, diffusion_model, compile_model: bool = False):\n        super().__init__()\n        compile = (\n            torch.compile\n            if (version.parse(torch.__version__) >= version.parse(\"2.0.0\"))\n            and compile_model\n            else lambda x: x\n        )\n        self.diffusion_model = compile(diffusion_model)\n\n    def forward(self, *args, **kwargs):\n        return self.diffusion_model(*args, **kwargs)\n\n\nclass OpenAIWrapper(IdentityWrapper):\n    def forward(\n        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs\n    ) -> torch.Tensor:\n        x = torch.cat((x, c.get(\"concat\", torch.Tensor([]).type_as(x))), dim=1)\n        if \"cond_view\" in c:\n            return self.diffusion_model(\n                x,\n                timesteps=t,\n                context=c.get(\"crossattn\", None),\n                y=c.get(\"vector\", None),\n                cond_view=c.get(\"cond_view\", None),\n                cond_motion=c.get(\"cond_motion\", None),\n                **kwargs,\n            )\n        else:\n            return self.diffusion_model(\n                x,\n                timesteps=t,\n                context=c.get(\"crossattn\", None),\n                y=c.get(\"vector\", None),\n                **kwargs,\n            )\n"
  },
  {
    "path": "sgm/modules/distributions/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/distributions/distributions.py",
    "content": "import numpy as np\nimport torch\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n\n    def mode(self):\n        raise NotImplementedError()\n\n\nclass DiracDistribution(AbstractDistribution):\n    def __init__(self, value):\n        self.value = value\n\n    def sample(self):\n        return self.value\n\n    def mode(self):\n        return self.value\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean).to(\n                device=self.parameters.device\n            )\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape).to(\n            device=self.parameters.device\n        )\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        else:\n            if other is None:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,\n                    dim=[1, 2, 3],\n                )\n            else:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var\n                    - 1.0\n                    - self.logvar\n                    + other.logvar,\n                    dim=[1, 2, 3],\n                )\n\n    def nll(self, sample, dims=[1, 2, 3]):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n   
         logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims,\n        )\n\n    def mode(self):\n        return self.mean\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + torch.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n"
  },
  {
    "path": "sgm/modules/ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError(\"Decay must be between 0 and 1\")\n\n        self.m_name2s_name = {}\n        self.register_buffer(\"decay\", torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer(\n            \"num_updates\",\n            torch.tensor(0, dtype=torch.int)\n            if use_num_upates\n            else torch.tensor(-1, dtype=torch.int),\n        )\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                # remove as '.'-character is not allowed in buffers\n                s_name = name.replace(\".\", \"\")\n                self.m_name2s_name.update({name: s_name})\n                self.register_buffer(s_name, p.clone().detach().data)\n\n        self.collected_params = []\n\n    def reset_num_updates(self):\n        del self.num_updates\n        self.register_buffer(\"num_updates\", torch.tensor(0, dtype=torch.int))\n\n    def forward(self, model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])\n                    shadow_params[sname].sub_(\n                        one_minus_decay * (shadow_params[sname] - m_param[key])\n                    )\n                else:\n                    assert not key in self.m_name2s_name\n\n    def 
copy_to(self, model):\n        m_param = dict(model.named_parameters())\n        shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, parameters):\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in parameters]\n\n    def restore(self, parameters):\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, parameters):\n            param.data.copy_(c_param.data)\n"
  },
  {
    "path": "sgm/modules/encoders/__init__.py",
    "content": ""
  },
  {
    "path": "sgm/modules/encoders/modules.py",
    "content": "import math\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport kornia\nimport numpy as np\nimport open_clip\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\nfrom omegaconf import ListConfig\nfrom torch.utils.checkpoint import checkpoint\nfrom transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,\n                          T5EncoderModel, T5Tokenizer)\n\nfrom ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer\nfrom ...modules.diffusionmodules.model import Encoder\nfrom ...modules.diffusionmodules.openaimodel import Timestep\nfrom ...modules.diffusionmodules.util import (extract_into_tensor,\n                                              make_beta_schedule)\nfrom ...modules.distributions.distributions import DiagonalGaussianDistribution\nfrom ...util import (append_dims, autocast, count_params, default,\n                     disabled_train, expand_dims_like, instantiate_from_config)\n\n\nclass AbstractEmbModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self._is_trainable = None\n        self._ucg_rate = None\n        self._input_key = None\n\n    @property\n    def is_trainable(self) -> bool:\n        return self._is_trainable\n\n    @property\n    def ucg_rate(self) -> Union[float, torch.Tensor]:\n        return self._ucg_rate\n\n    @property\n    def input_key(self) -> str:\n        return self._input_key\n\n    @is_trainable.setter\n    def is_trainable(self, value: bool):\n        self._is_trainable = value\n\n    @ucg_rate.setter\n    def ucg_rate(self, value: Union[float, torch.Tensor]):\n        self._ucg_rate = value\n\n    @input_key.setter\n    def input_key(self, value: str):\n        self._input_key = value\n\n    @is_trainable.deleter\n    def is_trainable(self):\n        del self._is_trainable\n\n    @ucg_rate.deleter\n    def ucg_rate(self):\n        del 
self._ucg_rate\n\n    @input_key.deleter\n    def input_key(self):\n        del self._input_key\n\n\nclass GeneralConditioner(nn.Module):\n    OUTPUT_DIM2KEYS = {2: \"vector\", 3: \"crossattn\", 4: \"concat\"} # , 5: \"concat\"}\n    KEY2CATDIM = {\"vector\": 1, \"crossattn\": 2, \"concat\": 1, \"cond_view\": 1, \"cond_motion\": 1}\n\n    def __init__(self, emb_models: Union[List, ListConfig]):\n        super().__init__()\n        embedders = []\n        for n, embconfig in enumerate(emb_models):\n            embedder = instantiate_from_config(embconfig)\n            assert isinstance(\n                embedder, AbstractEmbModel\n            ), f\"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel\"\n            embedder.is_trainable = embconfig.get(\"is_trainable\", False)\n            embedder.ucg_rate = embconfig.get(\"ucg_rate\", 0.0)\n            if not embedder.is_trainable:\n                embedder.train = disabled_train\n                for param in embedder.parameters():\n                    param.requires_grad = False\n                embedder.eval()\n            print(\n                f\"Initialized embedder #{n}: {embedder.__class__.__name__} \"\n                f\"with {count_params(embedder, False)} params. 
Trainable: {embedder.is_trainable}\"\n            )\n\n            if \"input_key\" in embconfig:\n                embedder.input_key = embconfig[\"input_key\"]\n            elif \"input_keys\" in embconfig:\n                embedder.input_keys = embconfig[\"input_keys\"]\n            else:\n                raise KeyError(\n                    f\"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}\"\n                )\n\n            embedder.legacy_ucg_val = embconfig.get(\"legacy_ucg_value\", None)\n            if embedder.legacy_ucg_val is not None:\n                embedder.ucg_prng = np.random.RandomState()\n\n            embedders.append(embedder)\n        self.embedders = nn.ModuleList(embedders)\n\n    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:\n        assert embedder.legacy_ucg_val is not None\n        p = embedder.ucg_rate\n        val = embedder.legacy_ucg_val\n        for i in range(len(batch[embedder.input_key])):\n            if embedder.ucg_prng.choice(2, p=[1 - p, p]):\n                batch[embedder.input_key][i] = val\n        return batch\n\n    def forward(\n        self, batch: Dict, force_zero_embeddings: Optional[List] = None\n    ) -> Dict:\n        output = dict()\n        if force_zero_embeddings is None:\n            force_zero_embeddings = []\n        for embedder in self.embedders:\n            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad\n            with embedding_context():\n                if hasattr(embedder, \"input_key\") and (embedder.input_key is not None):\n                    if embedder.legacy_ucg_val is not None:\n                        batch = self.possibly_get_ucg_val(embedder, batch)\n                    emb_out = embedder(batch[embedder.input_key])\n                elif hasattr(embedder, \"input_keys\"):\n                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])\n            assert isinstance(\n    
            emb_out, (torch.Tensor, list, tuple)\n            ), f\"encoder outputs must be tensors or a sequence, but got {type(emb_out)}\"\n            if not isinstance(emb_out, (list, tuple)):\n                emb_out = [emb_out]\n            for emb in emb_out:\n                if embedder.input_key in [\"cond_view\", \"cond_motion\"]:\n                    out_key = embedder.input_key\n                else:\n                    out_key = self.OUTPUT_DIM2KEYS[emb.dim()]\n\n                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:\n                    emb = (\n                        expand_dims_like(\n                            torch.bernoulli(\n                                (1.0 - embedder.ucg_rate)\n                                * torch.ones(emb.shape[0], device=emb.device)\n                            ),\n                            emb,\n                        )\n                        * emb\n                    )\n                if (\n                    hasattr(embedder, \"input_key\")\n                    and embedder.input_key in force_zero_embeddings\n                ):\n                    emb = torch.zeros_like(emb)\n                if out_key in output:\n                    output[out_key] = torch.cat(\n                        (output[out_key], emb), self.KEY2CATDIM[out_key]\n                    )\n                else:\n                    output[out_key] = emb\n        return output\n\n    def get_unconditional_conditioning(\n        self,\n        batch_c: Dict,\n        batch_uc: Optional[Dict] = None,\n        force_uc_zero_embeddings: Optional[List[str]] = None,\n        force_cond_zero_embeddings: Optional[List[str]] = None,\n    ):\n        if force_uc_zero_embeddings is None:\n            force_uc_zero_embeddings = []\n        ucg_rates = list()\n        for embedder in self.embedders:\n            ucg_rates.append(embedder.ucg_rate)\n            embedder.ucg_rate = 0.0\n        c = self(batch_c, 
force_cond_zero_embeddings)\n        uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)\n\n        for embedder, rate in zip(self.embedders, ucg_rates):\n            embedder.ucg_rate = rate\n        return c, uc\n\n\nclass InceptionV3(nn.Module):\n    \"\"\"Wrapper around the https://github.com/mseitzer/pytorch-fid inception\n    port with an additional squeeze at the end\"\"\"\n\n    def __init__(self, normalize_input=False, **kwargs):\n        super().__init__()\n        from pytorch_fid import inception\n\n        kwargs[\"resize_input\"] = True\n        self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)\n\n    def forward(self, inp):\n        outp = self.model(inp)\n\n        if len(outp) == 1:\n            return outp[0].squeeze()\n\n        return outp\n\n\nclass IdentityEncoder(AbstractEmbModel):\n    def encode(self, x):\n        return x\n\n    def forward(self, x):\n        return x\n\n\nclass ClassEmbedder(AbstractEmbModel):\n    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):\n        super().__init__()\n        self.embedding = nn.Embedding(n_classes, embed_dim)\n        self.n_classes = n_classes\n        self.add_sequence_dim = add_sequence_dim\n\n    def forward(self, c):\n        c = self.embedding(c)\n        if self.add_sequence_dim:\n            c = c[:, None, :]\n        return c\n\n    def get_unconditional_conditioning(self, bs, device=\"cuda\"):\n        uc_class = (\n            self.n_classes - 1\n        )  # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000)\n        uc = torch.ones((bs,), device=device) * uc_class\n        uc = {self.input_key: uc.long()}\n        return uc\n\n\nclass ClassEmbedderForMultiCond(ClassEmbedder):\n    def forward(self, batch, key=None, disable_dropout=False):\n        out = batch\n        key = default(key, self.input_key)\n        islist = isinstance(batch[key], list)\n        if islist:\n            batch[key] = batch[key][0]\n        # ClassEmbedder.forward only takes the class indices\n        c_out = super().forward(batch[key])\n        out[key] = [c_out] if islist else c_out\n        return out\n\n\nclass FrozenT5Embedder(AbstractEmbModel):\n    \"\"\"Uses the T5 transformer encoder for text\"\"\"\n\n    def __init__(\n        self, version=\"google/t5-v1_1-xxl\", device=\"cuda\", max_length=77, freeze=True\n    ):  # others are google/t5-v1_1-large and google/t5-v1_1-xl\n        super().__init__()\n        self.tokenizer = T5Tokenizer.from_pretrained(version)\n        self.transformer = T5EncoderModel.from_pretrained(version)\n        self.device = device\n        self.max_length = max_length\n        if freeze:\n            self.freeze()\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(\n            text,\n            truncation=True,\n            max_length=self.max_length,\n            return_length=True,\n            return_overflowing_tokens=False,\n            padding=\"max_length\",\n            return_tensors=\"pt\",\n        )\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        with torch.autocast(\"cuda\", enabled=False):\n            outputs = self.transformer(input_ids=tokens)\n        z = outputs.last_hidden_state\n        return z\n\n    def encode(self, text):\n        return self(text)\n\n\nclass FrozenByT5Embedder(AbstractEmbModel):\n    \"\"\"\n    Uses the ByT5 transformer 
encoder for text. Is character-aware.\n    \"\"\"\n\n    def __init__(\n        self, version=\"google/byt5-base\", device=\"cuda\", max_length=77, freeze=True\n    ):  # others are google/byt5-small and google/byt5-large\n        super().__init__()\n        self.tokenizer = ByT5Tokenizer.from_pretrained(version)\n        self.transformer = T5EncoderModel.from_pretrained(version)\n        self.device = device\n        self.max_length = max_length\n        if freeze:\n            self.freeze()\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(\n            text,\n            truncation=True,\n            max_length=self.max_length,\n            return_length=True,\n            return_overflowing_tokens=False,\n            padding=\"max_length\",\n            return_tensors=\"pt\",\n        )\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        with torch.autocast(\"cuda\", enabled=False):\n            outputs = self.transformer(input_ids=tokens)\n        z = outputs.last_hidden_state\n        return z\n\n    def encode(self, text):\n        return self(text)\n\n\nclass FrozenCLIPEmbedder(AbstractEmbModel):\n    \"\"\"Uses the CLIP transformer encoder for text (from huggingface)\"\"\"\n\n    LAYERS = [\"last\", \"pooled\", \"hidden\"]\n\n    def __init__(\n        self,\n        version=\"openai/clip-vit-large-patch14\",\n        device=\"cuda\",\n        max_length=77,\n        freeze=True,\n        layer=\"last\",\n        layer_idx=None,\n        always_return_pooled=False,\n    ):  # another option: openai/clip-vit-base-patch32\n        super().__init__()\n        assert layer in self.LAYERS\n        self.tokenizer = CLIPTokenizer.from_pretrained(version)\n        self.transformer = CLIPTextModel.from_pretrained(version)\n        self.device = device\n        self.max_length = 
max_length\n        if freeze:\n            self.freeze()\n        self.layer = layer\n        self.layer_idx = layer_idx\n        self.return_pooled = always_return_pooled\n        if layer == \"hidden\":\n            assert layer_idx is not None\n            assert 0 <= abs(layer_idx) <= 12\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n    @autocast\n    def forward(self, text):\n        batch_encoding = self.tokenizer(\n            text,\n            truncation=True,\n            max_length=self.max_length,\n            return_length=True,\n            return_overflowing_tokens=False,\n            padding=\"max_length\",\n            return_tensors=\"pt\",\n        )\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        outputs = self.transformer(\n            input_ids=tokens, output_hidden_states=self.layer == \"hidden\"\n        )\n        if self.layer == \"last\":\n            z = outputs.last_hidden_state\n        elif self.layer == \"pooled\":\n            z = outputs.pooler_output[:, None, :]\n        else:\n            z = outputs.hidden_states[self.layer_idx]\n        if self.return_pooled:\n            return z, outputs.pooler_output\n        return z\n\n    def encode(self, text):\n        return self(text)\n\n\nclass FrozenOpenCLIPEmbedder2(AbstractEmbModel):\n    \"\"\"\n    Uses the OpenCLIP transformer encoder for text\n    \"\"\"\n\n    LAYERS = [\"pooled\", \"last\", \"penultimate\"]\n\n    def __init__(\n        self,\n        arch=\"ViT-H-14\",\n        version=\"laion2b_s32b_b79k\",\n        device=\"cuda\",\n        max_length=77,\n        freeze=True,\n        layer=\"last\",\n        always_return_pooled=False,\n        legacy=True,\n    ):\n        super().__init__()\n        assert layer in self.LAYERS\n        model, _, _ = open_clip.create_model_and_transforms(\n            arch,\n            
device=torch.device(\"cpu\"),\n            pretrained=version,\n        )\n        del model.visual\n        self.model = model\n\n        self.device = device\n        self.max_length = max_length\n        self.return_pooled = always_return_pooled\n        if freeze:\n            self.freeze()\n        self.layer = layer\n        if self.layer == \"last\":\n            self.layer_idx = 0\n        elif self.layer == \"penultimate\":\n            self.layer_idx = 1\n        else:\n            raise NotImplementedError()\n        self.legacy = legacy\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    @autocast\n    def forward(self, text):\n        tokens = open_clip.tokenize(text)\n        z = self.encode_with_transformer(tokens.to(self.device))\n        if not self.return_pooled and self.legacy:\n            return z\n        if self.return_pooled:\n            assert not self.legacy\n            return z[self.layer], z[\"pooled\"]\n        return z[self.layer]\n\n    def encode_with_transformer(self, text):\n        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]\n        x = x + self.model.positional_embedding\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)\n        if self.legacy:\n            x = x[self.layer]\n            x = self.model.ln_final(x)\n            return x\n        else:\n            # x is a dict and will stay a dict\n            o = x[\"last\"]\n            o = self.model.ln_final(o)\n            pooled = self.pool(o, text)\n            x[\"pooled\"] = pooled\n            return x\n\n    def pool(self, x, text):\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = (\n            x[torch.arange(x.shape[0]), text.argmax(dim=-1)]\n            @ self.model.text_projection\n        )\n        
return x\n\n    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):\n        outputs = {}\n        for i, r in enumerate(self.model.transformer.resblocks):\n            if i == len(self.model.transformer.resblocks) - 1:\n                outputs[\"penultimate\"] = x.permute(1, 0, 2)  # LND -> NLD\n            if (\n                self.model.transformer.grad_checkpointing\n                and not torch.jit.is_scripting()\n            ):\n                x = checkpoint(r, x, attn_mask)\n            else:\n                x = r(x, attn_mask=attn_mask)\n        outputs[\"last\"] = x.permute(1, 0, 2)  # LND -> NLD\n        return outputs\n\n    def encode(self, text):\n        return self(text)\n\n\nclass FrozenOpenCLIPEmbedder(AbstractEmbModel):\n    LAYERS = [\n        # \"pooled\",\n        \"last\",\n        \"penultimate\",\n    ]\n\n    def __init__(\n        self,\n        arch=\"ViT-H-14\",\n        version=\"laion2b_s32b_b79k\",\n        device=\"cuda\",\n        max_length=77,\n        freeze=True,\n        layer=\"last\",\n    ):\n        super().__init__()\n        assert layer in self.LAYERS\n        model, _, _ = open_clip.create_model_and_transforms(\n            arch, device=torch.device(\"cpu\"), pretrained=version\n        )\n        del model.visual\n        self.model = model\n\n        self.device = device\n        self.max_length = max_length\n        if freeze:\n            self.freeze()\n        self.layer = layer\n        if self.layer == \"last\":\n            self.layer_idx = 0\n        elif self.layer == \"penultimate\":\n            self.layer_idx = 1\n        else:\n            raise NotImplementedError()\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        tokens = open_clip.tokenize(text)\n        z = self.encode_with_transformer(tokens.to(self.device))\n        return z\n\n    def 
encode_with_transformer(self, text):\n        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]\n        x = x + self.model.positional_embedding\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.model.ln_final(x)\n        return x\n\n    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):\n        for i, r in enumerate(self.model.transformer.resblocks):\n            if i == len(self.model.transformer.resblocks) - self.layer_idx:\n                break\n            if (\n                self.model.transformer.grad_checkpointing\n                and not torch.jit.is_scripting()\n            ):\n                x = checkpoint(r, x, attn_mask)\n            else:\n                x = r(x, attn_mask=attn_mask)\n        return x\n\n    def encode(self, text):\n        return self(text)\n\n\nclass FrozenOpenCLIPImageEmbedder(AbstractEmbModel):\n    \"\"\"\n    Uses the OpenCLIP vision transformer encoder for images\n    \"\"\"\n\n    def __init__(\n        self,\n        arch=\"ViT-H-14\",\n        version=\"laion2b_s32b_b79k\",\n        device=\"cuda\",\n        max_length=77,\n        freeze=True,\n        antialias=True,\n        ucg_rate=0.0,\n        unsqueeze_dim=False,\n        repeat_to_max_len=False,\n        num_image_crops=0,\n        output_tokens=False,\n        init_device=None,\n    ):\n        super().__init__()\n        model, _, _ = open_clip.create_model_and_transforms(\n            arch,\n            device=torch.device(default(init_device, \"cpu\")),\n            pretrained=version,\n        )\n        del model.transformer\n        self.model = model\n        self.max_crops = num_image_crops\n        self.pad_to_max_len = self.max_crops > 0\n        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)\n        self.device = device\n        self.max_length = 
max_length\n        if freeze:\n            self.freeze()\n\n        self.antialias = antialias\n\n        self.register_buffer(\n            \"mean\", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False\n        )\n        self.register_buffer(\n            \"std\", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False\n        )\n        self.ucg_rate = ucg_rate\n        self.unsqueeze_dim = unsqueeze_dim\n        self.stored_batch = None\n        self.model.visual.output_tokens = output_tokens\n        self.output_tokens = output_tokens\n\n    def preprocess(self, x):\n        # resize to the 224x224 CLIP input resolution\n        x = kornia.geometry.resize(\n            x,\n            (224, 224),\n            interpolation=\"bicubic\",\n            align_corners=True,\n            antialias=self.antialias,\n        )\n        # map from [-1, 1] to [0, 1]\n        x = (x + 1.0) / 2.0\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    @autocast\n    def forward(self, image, no_dropout=False):\n        z = self.encode_with_vision_transformer(image)\n        tokens = None\n        if self.output_tokens:\n            z, tokens = z[0], z[1]\n        z = z.to(image.dtype)\n        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):\n            z = (\n                torch.bernoulli(\n                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)\n                )[:, None]\n                * z\n            )\n            if tokens is not None:\n                tokens = (\n                    expand_dims_like(\n                        torch.bernoulli(\n                            (1.0 - self.ucg_rate)\n                            * torch.ones(tokens.shape[0], 
device=tokens.device)\n                        ),\n                        
tokens,\n                    )\n                    * tokens\n                )\n        if self.unsqueeze_dim:\n            z = z[:, None, :]\n        if self.output_tokens:\n            assert not self.repeat_to_max_len\n            assert not self.pad_to_max_len\n            return tokens, z\n        if self.repeat_to_max_len:\n            if z.dim() == 2:\n                z_ = z[:, None, :]\n            else:\n                z_ = z\n            return repeat(z_, \"b 1 d -> b n d\", n=self.max_length), z\n        elif self.pad_to_max_len:\n            assert z.dim() == 3\n            z_pad = torch.cat(\n                (\n                    z,\n                    torch.zeros(\n                        z.shape[0],\n                        self.max_length - z.shape[1],\n                        z.shape[2],\n                        device=z.device,\n                    ),\n                ),\n                1,\n            )\n            return z_pad, z_pad[:, 0, ...]\n        return z\n\n    def encode_with_vision_transformer(self, img):\n        # if self.max_crops > 0:\n        #    img = self.preprocess_by_cropping(img)\n        if img.dim() == 5:\n            assert self.max_crops == img.shape[1]\n            img = rearrange(img, \"b n c h w -> (b n) c h w\")\n        img = self.preprocess(img)\n        if not self.output_tokens:\n            assert not self.model.visual.output_tokens\n            x = self.model.visual(img)\n            tokens = None\n        else:\n            assert self.model.visual.output_tokens\n            x, tokens = self.model.visual(img)\n        if self.max_crops > 0:\n            x = rearrange(x, \"(b n) d -> b n d\", n=self.max_crops)\n            # drop out between 0 and all along the sequence axis\n            x = (\n                torch.bernoulli(\n                    (1.0 - self.ucg_rate)\n                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)\n                )\n                * x\n            )\n     
       if tokens is not None:\n                tokens = rearrange(tokens, \"(b n) t d -> b t (n d)\", n=self.max_crops)\n                print(\n                    f\"You are running very experimental token-concat in {self.__class__.__name__}. \"\n                    f\"Check what you are doing, and then remove this message.\"\n                )\n        if self.output_tokens:\n            return x, tokens\n        return x\n\n    def encode(self, image):\n        return self(image)\n\n\nclass FrozenCLIPT5Encoder(AbstractEmbModel):\n    def __init__(\n        self,\n        clip_version=\"openai/clip-vit-large-patch14\",\n        t5_version=\"google/t5-v1_1-xl\",\n        device=\"cuda\",\n        clip_max_length=77,\n        t5_max_length=77,\n    ):\n        super().__init__()\n        self.clip_encoder = FrozenCLIPEmbedder(\n            clip_version, device, max_length=clip_max_length\n        )\n        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)\n        print(\n            f\"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, \"\n            f\"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.\"\n        )\n\n    def encode(self, text):\n        return self(text)\n\n    def forward(self, text):\n        clip_z = self.clip_encoder.encode(text)\n        t5_z = self.t5_encoder.encode(text)\n        return [clip_z, t5_z]\n\n\nclass SpatialRescaler(nn.Module):\n    def __init__(\n        self,\n        n_stages=1,\n        method=\"bilinear\",\n        multiplier=0.5,\n        in_channels=3,\n        out_channels=None,\n        bias=False,\n        wrap_video=False,\n        kernel_size=1,\n        remap_output=False,\n    ):\n        super().__init__()\n        self.n_stages = n_stages\n        assert self.n_stages >= 0\n        assert method in [\n            \"nearest\",\n            \"linear\",\n            \"bilinear\",\n 
           \"trilinear\",\n            \"bicubic\",\n            \"area\",\n        ]\n        self.multiplier = multiplier\n        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)\n        self.remap_output = out_channels is not None or remap_output\n        if self.remap_output:\n            print(\n                f\"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.\"\n            )\n            self.channel_mapper = nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                bias=bias,\n                padding=kernel_size // 2,\n            )\n        self.wrap_video = wrap_video\n\n    def forward(self, x):\n        # only fold/unfold the time axis for 5D video input (b, c, t, h, w)\n        is_video = self.wrap_video and x.ndim == 5\n        if is_video:\n            B, C, T, H, W = x.shape\n            x = rearrange(x, \"b c t h w -> b t c h w\")\n            x = rearrange(x, \"b t c h w -> (b t) c h w\")\n\n        for stage in range(self.n_stages):\n            x = self.interpolator(x, scale_factor=self.multiplier)\n\n        if is_video:\n            x = rearrange(x, \"(b t) c h w -> b t c h w\", b=B, t=T, c=C)\n            x = rearrange(x, \"b t c h w -> b c t h w\")\n        if self.remap_output:\n            x = self.channel_mapper(x)\n        return x\n\n    def encode(self, x):\n        return self(x)\n\n\nclass LowScaleEncoder(nn.Module):\n    def __init__(\n        self,\n        model_config,\n        linear_start,\n        linear_end,\n        timesteps=1000,\n        max_noise_level=250,\n        output_size=64,\n        scale_factor=1.0,\n    ):\n        super().__init__()\n        self.max_noise_level = max_noise_level\n        self.model = instantiate_from_config(model_config)\n        self.augmentation_schedule = self.register_schedule(\n            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end\n        )\n        self.out_size = 
output_size\n        self.scale_factor = 
scale_factor\n\n    def register_schedule(\n        self,\n        beta_schedule=\"linear\",\n        timesteps=1000,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n    ):\n        betas = make_beta_schedule(\n            beta_schedule,\n            timesteps,\n            linear_start=linear_start,\n            linear_end=linear_end,\n            cosine_s=cosine_s,\n        )\n        alphas = 1.0 - betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])\n\n        (timesteps,) = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert (\n            alphas_cumprod.shape[0] == self.num_timesteps\n        ), \"alphas have to be defined for each timestep\"\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer(\"betas\", to_torch(betas))\n        self.register_buffer(\"alphas_cumprod\", to_torch(alphas_cumprod))\n        self.register_buffer(\"alphas_cumprod_prev\", to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\"sqrt_alphas_cumprod\", to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer(\n            \"sqrt_one_minus_alphas_cumprod\", to_torch(np.sqrt(1.0 - alphas_cumprod))\n        )\n        self.register_buffer(\n            \"log_one_minus_alphas_cumprod\", to_torch(np.log(1.0 - alphas_cumprod))\n        )\n        self.register_buffer(\n            \"sqrt_recip_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod))\n        )\n        self.register_buffer(\n            \"sqrt_recipm1_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))\n        )\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return (\n            
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n            * noise\n        )\n\n    def forward(self, x):\n        z = self.model.encode(x)\n        if isinstance(z, DiagonalGaussianDistribution):\n            z = z.sample()\n        z = z * self.scale_factor\n        noise_level = torch.randint(\n            0, self.max_noise_level, (x.shape[0],), device=x.device\n        ).long()\n        z = self.q_sample(z, noise_level)\n        if self.out_size is not None:\n            z = torch.nn.functional.interpolate(z, size=self.out_size, mode=\"nearest\")\n        return z, noise_level\n\n    def decode(self, z):\n        z = z / self.scale_factor\n        return self.model.decode(z)\n\n\nclass ConcatTimestepEmbedderND(AbstractEmbModel):\n    \"\"\"embeds each dimension independently and concatenates them\"\"\"\n\n    def __init__(self, outdim):\n        super().__init__()\n        self.timestep = Timestep(outdim)\n        self.outdim = outdim\n\n    def forward(self, x):\n        if x.ndim == 1:\n            x = x[:, None]\n        assert len(x.shape) == 2\n        b, dims = x.shape[0], x.shape[1]\n        x = rearrange(x, \"b d -> (b d)\")\n        emb = self.timestep(x)\n        emb = rearrange(emb, \"(b d) d2 -> b (d d2)\", b=b, d=dims, d2=self.outdim)\n        return emb\n\n\nclass GaussianEncoder(Encoder, AbstractEmbModel):\n    def __init__(\n        self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs\n    ):\n        super().__init__(*args, **kwargs)\n        self.posterior = DiagonalGaussianRegularizer()\n        self.weight = weight\n        self.flatten_output = flatten_output\n\n    def forward(self, x) -> Tuple[Dict, torch.Tensor]:\n        z = super().forward(x)\n        z, log = self.posterior(z)\n        log[\"loss\"] = log[\"kl_loss\"]\n        log[\"weight\"] = self.weight\n        if self.flatten_output:\n           
 z = rearrange(z, \"b c h w -> b (h w ) c\")\n        return log, z\n\n\nclass VideoPredictionEmbedderWithEncoder(AbstractEmbModel):\n    def __init__(\n        self,\n        n_cond_frames: int,\n        n_copies: int,\n        encoder_config: dict,\n        sigma_sampler_config: Optional[dict] = None,\n        sigma_cond_config: Optional[dict] = None,\n        is_ae: bool = False,\n        scale_factor: float = 1.0,\n        disable_encoder_autocast: bool = False,\n        en_and_decode_n_samples_a_time: Optional[int] = None,\n    ):\n        super().__init__()\n\n        self.n_cond_frames = n_cond_frames\n        self.n_copies = n_copies\n        self.encoder = instantiate_from_config(encoder_config)\n        self.sigma_sampler = (\n            instantiate_from_config(sigma_sampler_config)\n            if sigma_sampler_config is not None\n            else None\n        )\n        self.sigma_cond = (\n            instantiate_from_config(sigma_cond_config)\n            if sigma_cond_config is not None\n            else None\n        )\n        self.is_ae = is_ae\n        self.scale_factor = scale_factor\n        self.disable_encoder_autocast = disable_encoder_autocast\n        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time\n\n    def forward(\n        self, vid: torch.Tensor\n    ) -> Union[\n        torch.Tensor,\n        Tuple[torch.Tensor, torch.Tensor],\n        Tuple[torch.Tensor, dict],\n        Tuple[Tuple[torch.Tensor, torch.Tensor], dict],\n    ]:\n        if self.sigma_sampler is not None:\n            b = vid.shape[0] // self.n_cond_frames\n            sigmas = self.sigma_sampler(b).to(vid.device)\n            if self.sigma_cond is not None:\n                sigma_cond = self.sigma_cond(sigmas)\n                if self.n_cond_frames == 1:\n                    sigma_cond = repeat(sigma_cond, \"b d -> (b t) d\", t=self.n_copies)\n                else:\n                    sigma_cond = repeat(sigma_cond, \"b d -> (b t) d\", 
t=self.n_cond_frames) # For SV4D\n            sigmas = repeat(sigmas, \"b -> (b t)\", t=self.n_cond_frames)\n            noise = torch.randn_like(vid)\n            vid = vid + noise * append_dims(sigmas, vid.ndim)\n\n        with torch.autocast(\"cuda\", enabled=not self.disable_encoder_autocast):\n            n_samples = (\n                self.en_and_decode_n_samples_a_time\n                if self.en_and_decode_n_samples_a_time is not None\n                else vid.shape[0]\n            )\n            n_rounds = math.ceil(vid.shape[0] / n_samples)\n            all_out = []\n            for n in range(n_rounds):\n                if self.is_ae:\n                    out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])\n                else:\n                    out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])\n                all_out.append(out)\n\n        vid = torch.cat(all_out, dim=0)\n        vid *= self.scale_factor\n\n        if self.n_cond_frames == 1:\n            vid = rearrange(vid, \"(b t) c h w -> b () (t c) h w\", t=self.n_cond_frames)\n            vid = repeat(vid, \"b 1 c h w -> (b t) c h w\", t=self.n_copies)\n\n        return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid\n\n        return return_val\n\n\nclass FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):\n    def __init__(\n        self,\n        open_clip_embedding_config: Dict,\n        n_cond_frames: int,\n        n_copies: int,\n    ):\n        super().__init__()\n\n        self.n_cond_frames = n_cond_frames\n        self.n_copies = n_copies\n        self.open_clip = instantiate_from_config(open_clip_embedding_config)\n\n    def forward(self, vid):\n        vid = self.open_clip(vid)\n        vid = rearrange(vid, \"(b t) d -> b t d\", t=self.n_cond_frames)\n        vid = repeat(vid, \"b t d -> (b s) t d\", s=self.n_copies)\n\n        return vid\n"
  },
  {
    "path": "sgm/modules/spacetime_attention.py",
    "content": "from functools import partial\n\nimport torch\nimport torch.nn.functional as F\n\nfrom ..modules.attention import *\nfrom ..modules.diffusionmodules.util import (\n    AlphaBlender,\n    get_alpha,\n    linear,\n    mixed_checkpoint,\n    timestep_embedding,\n)\n\n\nclass TimeMixSequential(nn.Sequential):\n    def forward(self, x, context=None, timesteps=None):\n        for layer in self:\n            x = layer(x, context, timesteps)\n\n        return x\n\n\nclass BasicTransformerTimeMixBlock(nn.Module):\n    ATTENTION_MODES = {\n        \"softmax\": CrossAttention,\n        \"softmax-xformers\": MemoryEfficientCrossAttention,\n    }\n\n    def __init__(\n        self,\n        dim,\n        n_heads,\n        d_head,\n        dropout=0.0,\n        context_dim=None,\n        gated_ff=True,\n        checkpoint=True,\n        timesteps=None,\n        ff_in=False,\n        inner_dim=None,\n        attn_mode=\"softmax\",\n        disable_self_attn=False,\n        disable_temporal_crossattention=False,\n        switch_temporal_ca_to_sa=False,\n    ):\n        super().__init__()\n\n        attn_cls = self.ATTENTION_MODES[attn_mode]\n\n        self.ff_in = ff_in or inner_dim is not None\n        if inner_dim is None:\n            inner_dim = dim\n\n        assert int(n_heads * d_head) == inner_dim\n\n        self.is_res = inner_dim == dim\n\n        if self.ff_in:\n            self.norm_in = nn.LayerNorm(dim)\n            self.ff_in = FeedForward(\n                dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff\n            )\n\n        self.timesteps = timesteps\n        self.disable_self_attn = disable_self_attn\n        if self.disable_self_attn:\n            self.attn1 = attn_cls(\n                query_dim=inner_dim,\n                heads=n_heads,\n                dim_head=d_head,\n                context_dim=context_dim,\n                dropout=dropout,\n            )  # is a cross-attention\n        else:\n            self.attn1 = attn_cls(\n 
               query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout\n            )  # is a self-attention\n\n        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)\n\n        if disable_temporal_crossattention:\n            if switch_temporal_ca_to_sa:\n                raise ValueError(\n                    \"switch_temporal_ca_to_sa needs temporal cross-attention enabled\"\n                )\n            else:\n                self.attn2 = None\n        else:\n            self.norm2 = nn.LayerNorm(inner_dim)\n            if switch_temporal_ca_to_sa:\n                self.attn2 = attn_cls(\n                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout\n                )  # is a self-attention\n            else:\n                self.attn2 = attn_cls(\n                    query_dim=inner_dim,\n                    context_dim=context_dim,\n                    heads=n_heads,\n                    dim_head=d_head,\n                    dropout=dropout,\n                )  # is self-attn if context is none\n\n        self.norm1 = nn.LayerNorm(inner_dim)\n        self.norm3 = nn.LayerNorm(inner_dim)\n        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa\n\n        self.checkpoint = checkpoint\n        if self.checkpoint:\n            logpy.info(f\"{self.__class__.__name__} is using checkpointing\")\n\n    def forward(\n        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None\n    ) -> torch.Tensor:\n        if self.checkpoint:\n            return checkpoint(self._forward, x, context, timesteps)\n        else:\n            return self._forward(x, context, timesteps=timesteps)\n\n    def _forward(self, x, context=None, timesteps=None):\n        assert self.timesteps or timesteps\n        assert not (self.timesteps and timesteps) or self.timesteps == timesteps\n        timesteps = self.timesteps or timesteps\n        B, S, C = x.shape\n        x = rearrange(x, \"(b t) s c 
-> (b s) t c\", t=timesteps)\n\n        if self.ff_in:\n            x_skip = x\n            x = 
self.ff_in(self.norm_in(x))\n            if self.is_res:\n                x += x_skip\n\n        if self.disable_self_attn:\n            x = self.attn1(self.norm1(x), context=context) + x\n        else:\n            x = self.attn1(self.norm1(x)) + x\n\n        if self.attn2 is not None:\n            if self.switch_temporal_ca_to_sa:\n                x = self.attn2(self.norm2(x)) + x\n            else:\n                x = self.attn2(self.norm2(x), context=context) + x\n        x_skip = x\n        x = self.ff(self.norm3(x))\n        if self.is_res:\n            x += x_skip\n\n        x = rearrange(\n            x, \"(b s) t c -> (b t) s c\", s=S, b=B // timesteps, c=C, t=timesteps\n        )\n        return x\n\n    def get_last_layer(self):\n        return self.ff.net[-1].weight\n\n\nclass PostHocSpatialTransformerWithTimeMixing(SpatialTransformer):\n    def __init__(\n        self,\n        in_channels,\n        n_heads,\n        d_head,\n        depth=1,\n        dropout=0.0,\n        use_linear=False,\n        context_dim=None,\n        use_spatial_context=False,\n        timesteps=None,\n        merge_strategy: str = \"fixed\",\n        merge_factor: float = 0.5,\n        apply_sigmoid_to_merge: bool = True,\n        time_context_dim=None,\n        ff_in=False,\n        checkpoint=False,\n        time_depth=1,\n        attn_mode=\"softmax\",\n        disable_self_attn=False,\n        disable_temporal_crossattention=False,\n        time_mix_legacy: bool = True,\n        max_time_embed_period: int = 10000,\n    ):\n        super().__init__(\n            in_channels,\n            n_heads,\n            d_head,\n            depth=depth,\n            dropout=dropout,\n            attn_type=attn_mode,\n            use_checkpoint=checkpoint,\n            context_dim=context_dim,\n            use_linear=use_linear,\n            disable_self_attn=disable_self_attn,\n        )\n        self.time_depth = time_depth\n        self.depth = depth\n        
self.max_time_embed_period = max_time_embed_period\n\n        time_mix_d_head = d_head\n        n_time_mix_heads = n_heads\n\n        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)\n\n        inner_dim = n_heads * d_head\n        if use_spatial_context:\n            time_context_dim = context_dim\n\n        self.time_mix_blocks = nn.ModuleList(\n            [\n                BasicTransformerTimeMixBlock(\n                    inner_dim,\n                    n_time_mix_heads,\n                    time_mix_d_head,\n                    dropout=dropout,\n                    context_dim=time_context_dim,\n                    timesteps=timesteps,\n                    checkpoint=checkpoint,\n                    ff_in=ff_in,\n                    inner_dim=time_mix_inner_dim,\n                    attn_mode=attn_mode,\n                    disable_self_attn=disable_self_attn,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                )\n                for _ in range(self.depth)\n            ]\n        )\n\n        assert len(self.time_mix_blocks) == len(self.transformer_blocks)\n\n        self.use_spatial_context = use_spatial_context\n        self.in_channels = in_channels\n\n        time_embed_dim = self.in_channels * 4\n        self.time_mix_time_embed = nn.Sequential(\n            linear(self.in_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, self.in_channels),\n        )\n\n        self.time_mix_legacy = time_mix_legacy\n        if self.time_mix_legacy:\n            if merge_strategy == \"fixed\":\n                self.register_buffer(\"mix_factor\", torch.Tensor([merge_factor]))\n            elif merge_strategy == \"learned\" or merge_strategy == \"learned_with_images\":\n                self.register_parameter(\n                    \"mix_factor\", torch.nn.Parameter(torch.Tensor([merge_factor]))\n                )\n            elif merge_strategy == 
\"fixed_with_images\":\n                self.mix_factor = None\n            else:\n                raise ValueError(f\"unknown merge strategy {merge_strategy}\")\n\n            self.get_alpha_fn = partial(\n                get_alpha,\n                merge_strategy,\n                self.mix_factor,\n                apply_sigmoid=apply_sigmoid_to_merge,\n                is_attn=True,\n            )\n        else:\n            self.time_mixer = AlphaBlender(\n                alpha=merge_factor, merge_strategy=merge_strategy\n            )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: Optional[torch.Tensor] = None,\n        # cam: Optional[torch.Tensor] = None,\n        time_context: Optional[torch.Tensor] = None,\n        timesteps: Optional[int] = None,\n        image_only_indicator: Optional[torch.Tensor] = None,\n        cond_view: Optional[torch.Tensor] = None,\n        cond_motion: Optional[torch.Tensor] = None,\n        time_step: Optional[int] = None,\n        name: Optional[str] = None,\n    ) -> torch.Tensor:\n        _, _, h, w = x.shape\n        x_in = x\n        spatial_context = None\n        if exists(context):\n            spatial_context = context\n\n        if self.use_spatial_context:\n            assert (\n                context.ndim == 3\n            ), f\"n dims of spatial context should be 3 but are {context.ndim}\"\n\n            time_context = context\n            time_context_first_timestep = time_context[::timesteps]\n            time_context = repeat(\n                time_context_first_timestep, \"b ... -> (b n) ...\", n=h * w\n            )\n        elif time_context is not None and not self.use_spatial_context:\n            time_context = repeat(time_context, \"b ... 
-> (b n) ...\", n=h * w)\n            if time_context.ndim == 2:\n                time_context = rearrange(time_context, \"b c -> b 1 c\")\n\n        x = self.norm(x)\n        if not self.use_linear:\n            x = self.proj_in(x)\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n        if self.use_linear:\n            x = self.proj_in(x)\n\n        if self.time_mix_legacy:\n            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)\n\n        num_frames = torch.arange(timesteps, device=x.device)\n        num_frames = repeat(num_frames, \"t -> b t\", b=x.shape[0] // timesteps)\n        num_frames = rearrange(num_frames, \"b t -> (b t)\")\n        t_emb = timestep_embedding(\n            num_frames,\n            self.in_channels,\n            repeat_only=False,\n            max_period=self.max_time_embed_period,\n        )\n        emb = self.time_mix_time_embed(t_emb)\n        emb = emb[:, None, :]\n\n        for it_, (block, mix_block) in enumerate(\n            zip(self.transformer_blocks, self.time_mix_blocks)\n        ):\n            # spatial attention\n            x = block(\n                x,\n                context=spatial_context,\n                time_step=time_step, \n                name=name + '_' + str(it_)\n            )\n\n            x_mix = x\n            x_mix = x_mix + emb\n\n            # temporal attention\n            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)\n            if self.time_mix_legacy:\n                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix\n            else:\n                x = self.time_mixer(\n                    x_spatial=x,\n                    x_temporal=x_mix,\n                    image_only_indicator=image_only_indicator,\n                )\n\n        if self.use_linear:\n            x = self.proj_out(x)\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        if not self.use_linear:\n            x = self.proj_out(x)\n        out = 
x + x_in\n        return out\n    \n\nclass PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):\n    def __init__(\n        self,\n        in_channels,\n        n_heads,\n        d_head,\n        depth=1,\n        dropout=0.0,\n        use_linear=False,\n        context_dim=None,\n        use_spatial_context=False,\n        use_camera_emb=False,\n        use_3d_attention=False,\n        separate_motion_merge_factor=False,\n        adm_in_channels=None,\n        timesteps=None,\n        merge_strategy: str = \"fixed\",\n        merge_factor: float = 0.5,\n        merge_factor_motion: float = 0.5,\n        apply_sigmoid_to_merge: bool = True,\n        time_context_dim=None,\n        motion_context_dim=None,\n        ff_in=False,\n        checkpoint=False,\n        time_depth=1,\n        attn_mode=\"softmax\",\n        disable_self_attn=False,\n        disable_temporal_crossattention=False,\n        time_mix_legacy: bool = True,\n        max_time_embed_period: int = 10000,\n    ):\n        super().__init__(\n            in_channels,\n            n_heads,\n            d_head,\n            depth=depth,\n            dropout=dropout,\n            attn_type=attn_mode,\n            use_checkpoint=checkpoint,\n            context_dim=context_dim,\n            use_linear=use_linear,\n            disable_self_attn=disable_self_attn,\n        )\n        self.time_depth = time_depth\n        self.depth = depth\n        self.max_time_embed_period = max_time_embed_period\n        self.use_camera_emb = use_camera_emb\n        self.motion_context_dim = motion_context_dim\n        self.use_3d_attention = use_3d_attention\n        self.separate_motion_merge_factor = separate_motion_merge_factor\n\n        time_mix_d_head = d_head\n        n_time_mix_heads = n_heads\n\n        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)\n\n        inner_dim = n_heads * d_head\n        if use_spatial_context:\n            time_context_dim = context_dim\n\n        # 
Camera attention layer\n        self.time_mix_blocks = nn.ModuleList(\n            [\n                BasicTransformerTimeMixBlock(\n                    inner_dim,\n                    n_time_mix_heads,\n                    time_mix_d_head,\n                    dropout=dropout,\n                    context_dim=time_context_dim,\n                    timesteps=timesteps,\n                    checkpoint=checkpoint,\n                    ff_in=ff_in,\n                    inner_dim=time_mix_inner_dim,\n                    attn_mode=attn_mode,\n                    disable_self_attn=disable_self_attn,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                )\n                for _ in range(self.depth)\n            ]\n        )\n\n        # Motion attention layer\n        self.motion_blocks = nn.ModuleList(\n            [\n                BasicTransformerTimeMixBlock(\n                    inner_dim,\n                    n_time_mix_heads,\n                    time_mix_d_head,\n                    dropout=dropout,\n                    context_dim=motion_context_dim,\n                    timesteps=timesteps,\n                    checkpoint=checkpoint,\n                    ff_in=ff_in,\n                    inner_dim=time_mix_inner_dim,\n                    attn_mode=attn_mode,\n                    disable_self_attn=disable_self_attn,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                )\n                for _ in range(self.depth)\n            ]\n        )\n\n        assert len(self.time_mix_blocks) == len(self.transformer_blocks)\n\n        self.use_spatial_context = use_spatial_context\n        self.in_channels = in_channels\n\n        time_embed_dim = self.in_channels * 4\n        time_embed_channels = adm_in_channels if self.use_camera_emb else self.in_channels\n        # Camera view embedding\n        self.time_mix_time_embed = nn.Sequential(\n            
linear(time_embed_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, self.in_channels),\n        )\n        # Motion time embedding\n        self.time_mix_motion_embed = nn.Sequential(\n            linear(self.in_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, self.in_channels),\n        )\n\n        self.time_mix_legacy = time_mix_legacy\n        if self.time_mix_legacy:\n            if merge_strategy == \"fixed\":\n                self.register_buffer(\"mix_factor\", torch.Tensor([merge_factor]))\n            elif merge_strategy == \"learned\" or merge_strategy == \"learned_with_images\":\n                self.register_parameter(\n                    \"mix_factor\", torch.nn.Parameter(torch.Tensor([merge_factor]))\n                )\n            elif merge_strategy == \"fixed_with_images\":\n                self.mix_factor = None\n            else:\n                raise ValueError(f\"unknown merge strategy {merge_strategy}\")\n\n            self.get_alpha_fn = partial(\n                get_alpha,\n                merge_strategy,\n                self.mix_factor,\n                apply_sigmoid=apply_sigmoid_to_merge,\n                is_attn=True,\n            )\n        else:\n            self.time_mixer = AlphaBlender(\n                alpha=merge_factor, merge_strategy=merge_strategy\n            )\n            if self.separate_motion_merge_factor:\n                self.time_mixer_motion = AlphaBlender(\n                    alpha=merge_factor_motion, merge_strategy=merge_strategy\n                )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: Optional[torch.Tensor] = None,\n        cam: Optional[torch.Tensor] = None,\n        time_context: Optional[torch.Tensor] = None,\n        timesteps: Optional[int] = None,\n        image_only_indicator: Optional[torch.Tensor] = None,\n        cond_view: Optional[torch.Tensor] = None,\n        cond_motion: 
Optional[torch.Tensor] = None,\n        time_step: Optional[int] = None,\n        name: Optional[str] = None,\n    ) -> torch.Tensor:\n        # context: b t 1024\n        # cond_view: b*v 4 h w\n        # cond_motion: b*t 4 h w\n        # image_only_indicator: b t*v\n        b, t, d1 = context.shape # CLIP\n        v, d2 = cond_view.shape[0]//b, cond_view.shape[1] # VAE\n        _, c, h, w = x.shape\n\n        x_in = x\n        spatial_context = None\n        if exists(context):\n            spatial_context = context\n\n        cond_view = torch.nn.functional.interpolate(cond_view, size=(h,w), mode=\"bilinear\") # b*v d h w\n        spatial_context = context[:,:,None].repeat(1,1,v,1).reshape(-1,1,d1) # (b*t*v) 1 d1\n        camera_context = context[:,:,None].repeat(1,1,h*w,1).reshape(-1,1,d1) # (b*t*h*w) 1 d1\n        motion_context = cond_view.permute(0,2,3,1).reshape(-1,1,d2) # (b*v*h*w) 1 d2\n\n        x = self.norm(x)\n        if not self.use_linear:\n            x = self.proj_in(x)\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n        if self.use_linear:\n            x = self.proj_in(x)\n\n        if self.time_mix_legacy:\n            alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)\n\n        num_frames = torch.arange(t, device=x.device)\n        num_frames = repeat(num_frames, \"t -> b t\", b=b)\n        num_frames = rearrange(num_frames, \"b t -> (b t)\")\n        t_emb = timestep_embedding(\n            num_frames,\n            self.in_channels,\n            repeat_only=False,\n            max_period=self.max_time_embed_period,\n        )\n        emb_time = self.time_mix_motion_embed(t_emb)\n        emb_time = emb_time[:, None, :] # b*t 1 c\n\n        if self.use_camera_emb:\n            emb_view = self.time_mix_time_embed(cam.view(b,t,v,-1)[:,0].reshape(b*v,-1))\n            emb_view = emb_view[:, None, :]\n        else:\n            num_views = torch.arange(v, device=x.device)\n            num_views = repeat(num_views, \"t -> 
b t\", b=b)\n            num_views = rearrange(num_views, \"b t -> (b t)\")\n            v_emb = timestep_embedding(\n                num_views,\n                self.in_channels,\n                repeat_only=False,\n                max_period=self.max_time_embed_period,\n            )\n            emb_view = self.time_mix_time_embed(v_emb)\n            emb_view = emb_view[:, None, :] # b*v 1 c\n\n        if self.use_3d_attention:\n            emb_view = emb_view.repeat(1, h*w, 1).view(-1,1,c) # b*v*h*w 1 c\n\n        for it_, (block, time_block, mot_block) in enumerate(\n            zip(self.transformer_blocks, self.time_mix_blocks, self.motion_blocks)\n        ):\n            # Spatial attention\n            x = block(\n                x,\n                context=spatial_context,\n            )\n\n            # Camera attention\n            if self.use_3d_attention:\n                x = x.view(b, t, v, h*w, c).permute(0,2,3,1,4).reshape(-1,t,c) # b*v*h*w t c\n            else:\n                x = x.view(b, t, v, h*w, c).permute(0,2,1,3,4).reshape(b*v,-1,c) # b*v t*h*w c\n            x_mix = x + emb_view\n            x_mix = time_block(x_mix, context=camera_context, timesteps=v)\n            if self.time_mix_legacy:\n                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix\n            else:\n                x = self.time_mixer(\n                    x_spatial=x,\n                    x_temporal=x_mix,\n                    image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),\n                )\n\n            # Motion attention\n            if self.use_3d_attention:\n                x = x.view(b, v, h*w, t, c).permute(0,3,1,2,4).reshape(b*t,-1,c) # b*t v*h*w c\n            else:\n                x = x.view(b, v, t, h*w, c).permute(0,2,1,3,4).reshape(b*t,-1,c) # b*t v*h*w c\n            x_mix = x + emb_time\n            x_mix = mot_block(x_mix, context=motion_context, timesteps=t)\n            if 
self.time_mix_legacy:\n                x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix\n            else:\n                motion_mixer = self.time_mixer_motion if self.separate_motion_merge_factor else self.time_mixer\n                x = motion_mixer(\n                    x_spatial=x,\n                    x_temporal=x_mix,\n                    image_only_indicator=torch.zeros_like(image_only_indicator[:,:1].repeat(1,x.shape[0]//b)),\n                )\n\n            x = x.view(b, t, v, h*w, c).reshape(-1,h*w,c) # b*t*v h*w c\n\n        if self.use_linear:\n            x = self.proj_out(x)\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        if not self.use_linear:\n            x = self.proj_out(x)\n        out = x + x_in\n        return out"
  },
  {
    "path": "sgm/modules/video_attention.py",
    "content": "import torch\n\nfrom ..modules.attention import *\nfrom ..modules.diffusionmodules.util import (AlphaBlender, linear,\n                                             timestep_embedding)\n\n\nclass TimeMixSequential(nn.Sequential):\n    def forward(self, x, context=None, timesteps=None):\n        for layer in self:\n            x = layer(x, context, timesteps)\n\n        return x\n\n\nclass VideoTransformerBlock(nn.Module):\n    ATTENTION_MODES = {\n        \"softmax\": CrossAttention,\n        \"softmax-xformers\": MemoryEfficientCrossAttention,\n    }\n\n    def __init__(\n        self,\n        dim,\n        n_heads,\n        d_head,\n        dropout=0.0,\n        context_dim=None,\n        gated_ff=True,\n        checkpoint=True,\n        timesteps=None,\n        ff_in=False,\n        inner_dim=None,\n        attn_mode=\"softmax\",\n        disable_self_attn=False,\n        disable_temporal_crossattention=False,\n        switch_temporal_ca_to_sa=False,\n    ):\n        super().__init__()\n\n        attn_cls = self.ATTENTION_MODES[attn_mode]\n\n        self.ff_in = ff_in or inner_dim is not None\n        if inner_dim is None:\n            inner_dim = dim\n\n        assert int(n_heads * d_head) == inner_dim\n\n        self.is_res = inner_dim == dim\n\n        if self.ff_in:\n            self.norm_in = nn.LayerNorm(dim)\n            self.ff_in = FeedForward(\n                dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff\n            )\n\n        self.timesteps = timesteps\n        self.disable_self_attn = disable_self_attn\n        if self.disable_self_attn:\n            self.attn1 = attn_cls(\n                query_dim=inner_dim,\n                heads=n_heads,\n                dim_head=d_head,\n                context_dim=context_dim,\n                dropout=dropout,\n            )  # is a cross-attention\n        else:\n            self.attn1 = attn_cls(\n                query_dim=inner_dim, heads=n_heads, dim_head=d_head, 
dropout=dropout\n            )  # is a self-attention\n\n        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)\n\n        if disable_temporal_crossattention:\n            if switch_temporal_ca_to_sa:\n                raise ValueError\n            else:\n                self.attn2 = None\n        else:\n            self.norm2 = nn.LayerNorm(inner_dim)\n            if switch_temporal_ca_to_sa:\n                self.attn2 = attn_cls(\n                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout\n                )  # is a self-attention\n            else:\n                self.attn2 = attn_cls(\n                    query_dim=inner_dim,\n                    context_dim=context_dim,\n                    heads=n_heads,\n                    dim_head=d_head,\n                    dropout=dropout,\n                )  # is self-attn if context is none\n\n        self.norm1 = nn.LayerNorm(inner_dim)\n        self.norm3 = nn.LayerNorm(inner_dim)\n        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa\n\n        self.checkpoint = checkpoint\n        if self.checkpoint:\n            print(f\"{self.__class__.__name__} is using checkpointing\")\n\n    def forward(\n        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None\n    ) -> torch.Tensor:\n        if self.checkpoint:\n            return checkpoint(self._forward, x, context, timesteps)\n        else:\n            return self._forward(x, context, timesteps=timesteps)\n\n    def _forward(self, x, context=None, timesteps=None):\n        assert self.timesteps or timesteps\n        assert not (self.timesteps and timesteps) or self.timesteps == timesteps\n        timesteps = self.timesteps or timesteps\n        B, S, C = x.shape\n        x = rearrange(x, \"(b t) s c -> (b s) t c\", t=timesteps)\n\n        if self.ff_in:\n            x_skip = x\n            x = self.ff_in(self.norm_in(x))\n            if self.is_res:\n                x += 
x_skip\n\n        if self.disable_self_attn:\n            x = self.attn1(self.norm1(x), context=context) + x\n        else:\n            x = self.attn1(self.norm1(x)) + x\n\n        if self.attn2 is not None:\n            if self.switch_temporal_ca_to_sa:\n                x = self.attn2(self.norm2(x)) + x\n            else:\n                x = self.attn2(self.norm2(x), context=context) + x\n        x_skip = x\n        x = self.ff(self.norm3(x))\n        if self.is_res:\n            x += x_skip\n\n        x = rearrange(\n            x, \"(b s) t c -> (b t) s c\", s=S, b=B // timesteps, c=C, t=timesteps\n        )\n        return x\n\n    def get_last_layer(self):\n        return self.ff.net[-1].weight\n\n\nclass SpatialVideoTransformer(SpatialTransformer):\n    def __init__(\n        self,\n        in_channels,\n        n_heads,\n        d_head,\n        depth=1,\n        dropout=0.0,\n        use_linear=False,\n        context_dim=None,\n        use_spatial_context=False,\n        timesteps=None,\n        merge_strategy: str = \"fixed\",\n        merge_factor: float = 0.5,\n        time_context_dim=None,\n        ff_in=False,\n        checkpoint=False,\n        time_depth=1,\n        attn_mode=\"softmax\",\n        disable_self_attn=False,\n        disable_temporal_crossattention=False,\n        max_time_embed_period: int = 10000,\n    ):\n        super().__init__(\n            in_channels,\n            n_heads,\n            d_head,\n            depth=depth,\n            dropout=dropout,\n            attn_type=attn_mode,\n            use_checkpoint=checkpoint,\n            context_dim=context_dim,\n            use_linear=use_linear,\n            disable_self_attn=disable_self_attn,\n        )\n        self.time_depth = time_depth\n        self.depth = depth\n        self.max_time_embed_period = max_time_embed_period\n\n        time_mix_d_head = d_head\n        n_time_mix_heads = n_heads\n\n        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)\n\n    
    inner_dim = n_heads * d_head\n        if use_spatial_context:\n            time_context_dim = context_dim\n\n        self.time_stack = nn.ModuleList(\n            [\n                VideoTransformerBlock(\n                    inner_dim,\n                    n_time_mix_heads,\n                    time_mix_d_head,\n                    dropout=dropout,\n                    context_dim=time_context_dim,\n                    timesteps=timesteps,\n                    checkpoint=checkpoint,\n                    ff_in=ff_in,\n                    inner_dim=time_mix_inner_dim,\n                    attn_mode=attn_mode,\n                    disable_self_attn=disable_self_attn,\n                    disable_temporal_crossattention=disable_temporal_crossattention,\n                )\n                for _ in range(self.depth)\n            ]\n        )\n\n        assert len(self.time_stack) == len(self.transformer_blocks)\n\n        self.use_spatial_context = use_spatial_context\n        self.in_channels = in_channels\n\n        time_embed_dim = self.in_channels * 4\n        self.time_pos_embed = nn.Sequential(\n            linear(self.in_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, self.in_channels),\n        )\n\n        self.time_mixer = AlphaBlender(\n            alpha=merge_factor, merge_strategy=merge_strategy\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        context: Optional[torch.Tensor] = None,\n        time_context: Optional[torch.Tensor] = None,\n        timesteps: Optional[int] = None,\n        image_only_indicator: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        _, _, h, w = x.shape\n        x_in = x\n        spatial_context = None\n        if exists(context):\n            spatial_context = context\n\n        if self.use_spatial_context:\n            assert (\n                context.ndim == 3\n            ), f\"n dims of spatial context should be 3 but are 
{context.ndim}\"\n\n            time_context = context\n            time_context_first_timestep = time_context[::timesteps]\n            time_context = repeat(\n                time_context_first_timestep, \"b ... -> (b n) ...\", n=h * w\n            )\n        elif time_context is not None and not self.use_spatial_context:\n            time_context = repeat(time_context, \"b ... -> (b n) ...\", n=h * w)\n            if time_context.ndim == 2:\n                time_context = rearrange(time_context, \"b c -> b 1 c\")\n\n        x = self.norm(x)\n        if not self.use_linear:\n            x = self.proj_in(x)\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n        if self.use_linear:\n            x = self.proj_in(x)\n\n        num_frames = torch.arange(timesteps, device=x.device)\n        num_frames = repeat(num_frames, \"t -> b t\", b=x.shape[0] // timesteps)\n        num_frames = rearrange(num_frames, \"b t -> (b t)\")\n        t_emb = timestep_embedding(\n            num_frames,\n            self.in_channels,\n            repeat_only=False,\n            max_period=self.max_time_embed_period,\n        )\n        emb = self.time_pos_embed(t_emb)\n        emb = emb[:, None, :]\n\n        for it_, (block, mix_block) in enumerate(\n            zip(self.transformer_blocks, self.time_stack)\n        ):\n            x = block(\n                x,\n                context=spatial_context,\n            )\n\n            x_mix = x\n            x_mix = x_mix + emb\n\n            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)\n            x = self.time_mixer(\n                x_spatial=x,\n                x_temporal=x_mix,\n                image_only_indicator=image_only_indicator,\n            )\n        if self.use_linear:\n            x = self.proj_out(x)\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        if not self.use_linear:\n            x = self.proj_out(x)\n        out = x + x_in\n        return out\n"
  },
  {
    "path": "sgm/util.py",
    "content": "import functools\nimport importlib\nimport os\nfrom functools import partial\nfrom inspect import isfunction\n\nimport fsspec\nimport numpy as np\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nfrom safetensors.torch import load_file as load_safetensors\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef get_string_from_tuple(s):\n    try:\n        # Check if the string starts and ends with parentheses\n        if s[0] == \"(\" and s[-1] == \")\":\n            # Convert the string to a tuple\n            t = eval(s)\n            # Check if the type of t is tuple\n            if type(t) == tuple:\n                return t[0]\n            else:\n                pass\n    except:\n        pass\n    return s\n\n\ndef is_power_of_two(n):\n    \"\"\"\n    chat.openai.com/chat\n    Return True if n is a power of 2, otherwise return False.\n\n    The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.\n    The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.\n    If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.\n    Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. 
Otherwise, the function returns False.\n\n    \"\"\"\n    if n <= 0:\n        return False\n    return (n & (n - 1)) == 0\n\n\ndef autocast(f, enabled=True):\n    def do_autocast(*args, **kwargs):\n        with torch.cuda.amp.autocast(\n            enabled=enabled,\n            dtype=torch.get_autocast_gpu_dtype(),\n            cache_enabled=torch.is_autocast_cache_enabled(),\n        ):\n            return f(*args, **kwargs)\n\n    return do_autocast\n\n\ndef load_partial_from_config(config):\n    return partial(get_obj_from_str(config[\"target\"]), **config.get(\"params\", dict()))\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new(\"RGB\", wh, color=\"white\")\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.truetype(\"data/DejaVuSans.ttf\", size=size)\n        nc = int(40 * (wh[0] / 256))\n        if isinstance(xc[bi], list):\n            text_seq = xc[bi][0]\n        else:\n            text_seq = xc[bi]\n        lines = \"\\n\".join(\n            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)\n        )\n\n        try:\n            draw.text((0, 0), lines, fill=\"black\", font=font)\n        except UnicodeEncodeError:\n            print(\"Can't encode string for logging. 
Skipping.\")\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef partialclass(cls, *args, **kwargs):\n    class NewCls(cls):\n        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)\n\n    return NewCls\n\n\ndef make_path_absolute(path):\n    fs, p = fsspec.core.url_to_fs(path)\n    if fs.protocol == \"file\":\n        return os.path.abspath(p)\n    return path\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef isheatmap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n\n    return x.ndim == 2\n\n\ndef isneighbors(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef expand_dims_like(x, y):\n    while x.dim() != y.dim():\n        x = x.unsqueeze(-1)\n    return x\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        print(f\"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config):\n    if \"target\" not in config:\n        if config == \"__is_first_stage__\":\n            return None\n        elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\ndef get_obj_from_str(string, reload=False, invalidate_cache=True):\n    module, cls = string.rsplit(\".\", 1)\n    if invalidate_cache:\n        importlib.invalidate_caches()\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef append_zero(x):\n    return torch.cat([x, x.new_zeros([1])])\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(\n            f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\"\n        )\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef load_model_from_config(config, ckpt, verbose=True, freeze=True):\n    print(f\"Loading model from {ckpt}\")\n    if ckpt.endswith(\"ckpt\"):\n        pl_sd = torch.load(ckpt, map_location=\"cpu\")\n        if \"global_step\" in pl_sd:\n            print(f\"Global Step: {pl_sd['global_step']}\")\n        sd = pl_sd[\"state_dict\"]\n    elif ckpt.endswith(\"safetensors\"):\n        sd = load_safetensors(ckpt)\n    else:\n        raise NotImplementedError\n\n    model = instantiate_from_config(config.model)\n\n    m, u = model.load_state_dict(sd, strict=False)\n\n    if len(m) > 0 and verbose:\n        print(\"missing keys:\")\n        print(m)\n    if len(u) > 0 and verbose:\n        print(\"unexpected keys:\")\n        print(u)\n\n    if freeze:\n        for param in model.parameters():\n            param.requires_grad = False\n\n    model.eval()\n    return model\n\n\ndef get_configs_path() -> str:\n    \"\"\"\n    Get the `configs` directory.\n    For a working copy, this is the one in the root of the repository,\n    but for an installed copy, it's in the `sgm` package (see pyproject.toml).\n    \"\"\"\n    this_dir = os.path.dirname(__file__)\n    candidates = (\n        os.path.join(this_dir, \"configs\"),\n        os.path.join(this_dir, \"..\", \"configs\"),\n    )\n    for candidate in candidates:\n        candidate = os.path.abspath(candidate)\n        if os.path.isdir(candidate):\n            return candidate\n    raise FileNotFoundError(f\"Could not find SGM configs in {candidates}\")\n\n\ndef get_nested_attribute(obj, attribute_path, depth=None, return_key=False):\n    \"\"\"\n    Will return the result of a recursive get attribute call.\n    E.g.:\n        a.b.c\n        = getattr(getattr(a, \"b\"), \"c\")\n        = get_nested_attribute(a, \"b.c\")\n    If any part of the attribute call is an integer x with current obj a, will\n    try to call a[x] instead of a.x first.\n    \"\"\"\n    attributes = attribute_path.split(\".\")\n    if depth is not None and depth > 0:\n        attributes = attributes[:depth]\n    assert len(attributes) > 0, \"At least one attribute should be selected\"\n    current_attribute = obj\n    current_key = None\n    for level, attribute in enumerate(attributes):\n        current_key = \".\".join(attributes[: level + 1])\n        try:\n            id_ = int(attribute)\n            current_attribute = current_attribute[id_]\n        except ValueError:\n            current_attribute = getattr(current_attribute, attribute)\n\n    return (current_attribute, current_key) if return_key else current_attribute\n"
  },
  {
    "path": "tests/inference/test_inference.py",
    "content": "import numpy\nfrom PIL import Image\nimport pytest\nfrom pytest import fixture\nimport torch\nfrom typing import Tuple\n\nfrom sgm.inference.api import (\n    model_specs,\n    SamplingParams,\n    SamplingPipeline,\n    Sampler,\n    ModelArchitecture,\n)\nimport sgm.inference.helpers as helpers\n\n\n@pytest.mark.inference\nclass TestInference:\n    @fixture(scope=\"class\", params=model_specs.keys())\n    def pipeline(self, request) -> SamplingPipeline:\n        pipeline = SamplingPipeline(request.param)\n        yield pipeline\n        del pipeline\n        torch.cuda.empty_cache()\n\n    @fixture(\n        scope=\"class\",\n        params=[\n            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],\n            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],\n        ],\n        ids=[\"SDXL_V1\", \"SDXL_V0_9\"],\n    )\n    def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:\n        base_pipeline = SamplingPipeline(request.param[0])\n        refiner_pipeline = SamplingPipeline(request.param[1])\n        yield base_pipeline, refiner_pipeline\n        del base_pipeline\n        del refiner_pipeline\n        torch.cuda.empty_cache()\n\n    def create_init_image(self, h, w):\n        image_array = numpy.random.rand(h, w, 3) * 255\n        image = Image.fromarray(image_array.astype(\"uint8\")).convert(\"RGB\")\n        return helpers.get_input_image_tensor(image)\n\n    @pytest.mark.parametrize(\"sampler_enum\", Sampler)\n    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):\n        output = pipeline.text_to_image(\n            params=SamplingParams(sampler=sampler_enum.value, steps=10),\n            prompt=\"A professional photograph of an astronaut riding a pig\",\n            negative_prompt=\"\",\n            samples=1,\n        )\n\n        assert output is not None\n\n    @pytest.mark.parametrize(\"sampler_enum\", Sampler)\n    def test_img2img(self, 
pipeline: SamplingPipeline, sampler_enum):\n        output = pipeline.image_to_image(\n            params=SamplingParams(sampler=sampler_enum.value, steps=10),\n            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),\n            prompt=\"A professional photograph of an astronaut riding a pig\",\n            negative_prompt=\"\",\n            samples=1,\n        )\n        assert output is not None\n\n    @pytest.mark.parametrize(\"sampler_enum\", Sampler)\n    @pytest.mark.parametrize(\n        \"use_init_image\", [True, False], ids=[\"img2img\", \"txt2img\"]\n    )\n    def test_sdxl_with_refiner(\n        self,\n        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],\n        sampler_enum,\n        use_init_image,\n    ):\n        base_pipeline, refiner_pipeline = sdxl_pipelines\n        if use_init_image:\n            output = base_pipeline.image_to_image(\n                params=SamplingParams(sampler=sampler_enum.value, steps=10),\n                image=self.create_init_image(\n                    base_pipeline.specs.height, base_pipeline.specs.width\n                ),\n                prompt=\"A professional photograph of an astronaut riding a pig\",\n                negative_prompt=\"\",\n                samples=1,\n                return_latents=True,\n            )\n        else:\n            output = base_pipeline.text_to_image(\n                params=SamplingParams(sampler=sampler_enum.value, steps=10),\n                prompt=\"A professional photograph of an astronaut riding a pig\",\n                negative_prompt=\"\",\n                samples=1,\n                return_latents=True,\n            )\n\n        assert isinstance(output, (tuple, list))\n        samples, samples_z = output\n        assert samples is not None\n        assert samples_z is not None\n        refiner_output = refiner_pipeline.refiner(\n            params=SamplingParams(sampler=sampler_enum.value, steps=10),\n            image=samples_z,\n            prompt=\"A professional photograph of an astronaut riding a pig\",\n            negative_prompt=\"\",\n            samples=1,\n        )\n        assert refiner_output is not None\n"
  }
]