Full Code of deep-floyd/IF for AI

develop ffc816389168 cached
38 files
165.4 KB
42.9k tokens
144 symbols
1 requests
Download .txt
Repository: deep-floyd/IF
Branch: develop
Commit: ffc816389168
Files: 38
Total size: 165.4 KB

Directory structure:
gitextract_r39ejdyw/

├── .gitattributes
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── LICENSE
├── LICENSE-MODEL
├── README.md
├── deepfloyd_if/
│   ├── __init__.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── gaussian_diffusion.py
│   │   ├── losses.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── respace.py
│   │   └── unet.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── stage_I.py
│   │   ├── stage_II.py
│   │   ├── stage_III.py
│   │   ├── stage_III_sd_x4.py
│   │   ├── t5.py
│   │   └── utils.py
│   ├── pipelines/
│   │   ├── __init__.py
│   │   ├── dream.py
│   │   ├── inpainting.py
│   │   ├── style_transfer.py
│   │   ├── super_resolution.py
│   │   └── utils.py
│   ├── resources/
│   │   ├── p_head_v1.npz
│   │   ├── w_head_v1.npz
│   │   └── zero_t5-v1_1-xxl_vector.npy
│   └── utils.py
├── requirements-dev.txt
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
notebooks/pipes-DeepFloyd-IF.ipynb filter=lfs diff=lfs merge=lfs -text


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, deepfloyd_if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.2.0
    hooks:
    -   id: check-docstring-first
    -   id: check-merge-conflict
        stages:
        - push
    -   id: double-quote-string-fixer
    -   id: end-of-file-fixer
    -   id: fix-encoding-pragma
    -   id: mixed-line-ending
    -   id: trailing-whitespace
-   repo: https://github.com/pycqa/flake8
    rev: "4.0.1"
    hooks:
    -   id: flake8
        args: ['--config=setup.cfg']
-   repo: https://github.com/pre-commit/mirrors-autopep8
    rev: v1.6.0
    hooks:
    -   id: autopep8


================================================
FILE: CHANGELOG.md
================================================
v1.0.2rc
-------

- uses separated tokenizer_path to init tokenizer in T5Embedder

v1.0.1
------

- renamed main model `IF-I-IF` --> `IF-I-XL`
- moved dir `notebooks` to HF storage https://huggingface.co/DeepFloyd/IF-notebooks; lets keep new notebooks there;
- added additional kaggle notebook (more free GPU resources) how to generate pictures 1k: [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)

v1.0.0
------

- initial version


================================================
FILE: LICENSE
================================================
Copyright (c) 2023 DeepFloyd, StabilityAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

1. The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

2. All persons obtaining a copy or substantial portion of the Software,
a modified version of the Software (or substantial portion thereof), or
a derivative work based upon this Software (or substantial portion thereof)
must not delete, remove, disable, diminish, or circumvent any inference filters or
inference filter mechanisms in the Software, or any portion of the Software that
implements any such filters or filter mechanisms.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: LICENSE-MODEL
================================================
DEEPFLOYD IF LICENSE AGREEMENT

This License Agreement (as may be amended in accordance with this License Agreement, “License”),
between you, or your employer or other entity (if you are entering into this agreement on behalf
of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd.. (“Stability AI” or “we”)
applies to your use of any computer program, algorithm, source code, object code, or software that is made
available by Stability AI under this License (“Software”) and any specifications, manuals, documentation,
and other written information provided by Stability AI related to the Software (“Documentation”).
By clicking “I Accept” below or by using the Software, you agree to the terms of this License.
If you do not agree to this License, then you do not have any rights to use the Software or
Documentation (collectively, the “Software Products”), and you must immediately cease using
the Software Products. If you are agreeing to be bound by the terms of this License on behalf
of your employer or other entity, you represent and warrant to Stability AI that you have full legal
authority to bind your employer or such entity to this License. If you do not have the requisite authority,
you may not accept the License or access the Software Products on behalf of your employer or other entity.

1. LICENSE GRANT

a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants
you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited
license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of
the Software solely for your non-commercial research purposes. The foregoing license is personal to you,
and you may not assign or sublicense this License or any other rights or obligations under this License
without Stability AI’s prior written consent; any such assignment or sublicense will be void and will
automatically and immediately terminate this License.

b. You may make a reasonable number of copies of the Documentation solely for use in connection with
the license to the Software granted above.

c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete
grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver,
estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights
not expressly granted by this License.


2. RESTRICTIONS

You will not, and will not permit, assist or cause any third party to:

a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products
(or any derivative works thereof, works incorporating the Software Products, or any data produced
by the Software), in whole or in part, for (i) any commercial or production purposes,
(ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance,
including any research or development relating to surveillance, (iv) biometric processing,
(v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights,
or (vi) in any manner that violates any applicable law and violating any privacy or security laws,
rules, regulations, directives, or governmental requirements (including the General Data Privacy
Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws
governing the processing of biometric information), as well as all amendments and successor laws
to any of the foregoing;

b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;

c. utilize any equipment, device, software, or other means to circumvent or remove any security or
protection used by Stability AI in connection with the Software, or to circumvent or remove any
usage restrictions, or to enable functionality disabled by Stability AI; or

d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent
with the terms of this License.

e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws
(“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise
transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b)
to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited
by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications;
3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned
jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any
purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.


3. ATTRIBUTION

Together with any copies of the Software Products (as well as derivative works thereof or works
incorporating the Software Products) that you distribute, you must provide (i) a copy of this License,
and (ii) the following attribution notice: “DeepFloyd is licensed under the DeepFloyd License,
Copyright (c) Stability AI Ltd. All Rights Reserved.”


4. DISCLAIMERS

THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED. STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED,
WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS,
INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,
TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS
THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS,
OR PRODUCE ANY PARTICULAR RESULTS.


5. LIMITATION OF LIABILITY

TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER
ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY,
OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL,
PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY
OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT
(COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR
SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD
TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S
PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”).
IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK.
YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND
POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY
OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL
THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.


6. INDEMNIFICATION

You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates,
and each of our respective shareholders, directors, officers, employees, agents, successors,
and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities,
damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any
Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or
investigation (collectively, “Claims”) arising out of or related to: (a) your access to or
use of the Software Products (as well as any results or data generated from such access or use),
including any High-Risk Use (defined below); (b) your violation of this License; or (c)
your violation, misappropriation or infringement of any rights of another (including intellectual
property or other proprietary rights and privacy rights). You will promptly notify the Stability AI
Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims.
You will also grant the Stability AI Parties sole control of the defense or settlement,
at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of,
any other indemnities or remedies set forth in a written agreement between you and
Stability AI or the other Stability AI Parties.


7. TERMINATION; SURVIVAL

a. This License will automatically terminate upon any breach by you of the terms of this License.

b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.

c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution),
4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival),
8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).


8. THIRD PARTY MATERIALS

The Software Products may contain third-party software or other components (including free and
open source software) (all of the foregoing, “Third Party Materials”), which are subject to
the license terms of the respective third-party licensors. Your dealings or correspondence
with third parties and your use of or interaction with any Third Party Materials are solely
between you and the third party. Stability AI does not control or endorse, and makes
no representations or warranties regarding, any Third Party Materials, and your access
to and use of such Third Party Materials are at your own risk.


9. TRADEMARKS

Licensee has not been granted any trademark license as part of this License and may not use any name
or mark associated with Stability AI without the prior written permission of Stability AI, except to
the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.


10. APPLICABLE LAW; DISPUTE RESOLUTION

This License will be governed and construed under the laws of the State of California without regard
to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License
will be brought in the federal or state courts, as applicable, in San Mateo County, California,
and each party irrevocably submits to the jurisdiction and venue of such courts.


11. MISCELLANEOUS

If any provision or part of a provision of this License is unlawful, void or unenforceable,
that provision or part of the provision is deemed severed from this License, and will not affect
the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise
or enforce any right or provision of this License will not operate as a waiver of such right or provision.
This License does not confer any third-party beneficiary rights upon any other person or entity.
This License, together with the Documentation, contains the entire understanding between you and
Stability AI regarding the subject matter of this License, and supersedes all other written or
oral agreements and understandings between you and Stability AI regarding such subject matter.
No change or addition to any provision of this License will be binding unless it is in writing and
signed by an authorized representative of both you and Stability AI.


================================================
FILE: README.md
================================================
[![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE)
[![License](https://img.shields.io/badge/Weights_License-DeepFloyd_IF-orange.svg)](LICENSE-MODEL)
[![Downloads](https://pepy.tech/badge/deepfloyd_if)](https://pepy.tech/project/deepfloyd_if)
[![Discord](https://img.shields.io/badge/Discord-%237289DA.svg?logo=discord&logoColor=white)](https://discord.gg/umz62Mgr)
[![Twitter](https://img.shields.io/badge/Twitter-%231DA1F2.svg?logo=twitter&logoColor=white)](https://twitter.com/deepfloydai)
[![Linktree](https://img.shields.io/badge/Linktree-%2339E09B.svg?logo=linktree&logoColor=white)](http://linktr.ee/deepfloyd)

# IF by [DeepFloyd Lab](https://deepfloyd.ai) at [StabilityAI](https://stability.ai/)

<p align="center">
  <img src="./pics/nabla.jpg" width="100%">
</p>

We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates 64x64 px image based on text prompt and two super-resolution models, each designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.

<p align="center">
  <img src="./pics/deepfloyd_if_scheme.jpg" width="100%">
</p>

*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf)

## Minimum requirements to use all IF models:
- 16GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module)
- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler)
- `xformers` and set env variable `FORCE_MEM_EFFICIENT_ATTN=1`


## Quick Start
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF)

```shell
pip install deepfloyd_if==1.0.2rc0
pip install xformers==0.0.16
pip install git+https://github.com/openai/CLIP.git --no-deps
```

## Local notebooks
[![Jupyter Notebook](https://img.shields.io/badge/jupyter_notebook-%23FF7A01.svg?logo=jupyter&logoColor=white)](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb)
[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)

The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb).



## Integration with 🤗 Diffusers

IF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/).

Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing to inspect intermediate results easily.

### Example

Before you can use IF, you need to accept its usage conditions. To do so:
1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be loggin in
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
3. Make sure to login locally. Install `huggingface_hub`
```sh
pip install huggingface_hub --upgrade
```

run the login function in a Python shell

```py
from huggingface_hub import login

login()
```

and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).

Next we install `diffusers` and dependencies:

```sh
pip install diffusers accelerate transformers safetensors
```

And we can now run the model locally.

By default `diffusers` makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM.

If you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()`
functions.

```py
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch

# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()

prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'

# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

generator = torch.manual_seed(0)

# stage 1
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")

# stage 2
image = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("./if_stage_III.png")
```

 There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To do so, please have a look at the Diffusers docs:

- 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed)
- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory)

For more in-detail information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖.

Diffusers dreambooth scripts also supports fine-tuning 🎨 [IF](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#if).
With parameter efficient finetuning, you can add new concepts to IF with a single GPU and ~28 GB VRAM.

## Run the code locally

### Loading the models into VRAM

```python
from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII
from deepfloyd_if.modules.t5 import T5Embedder

device = 'cuda:0'
if_I = IFStageI('IF-I-XL-v1.0', device=device)
if_II = IFStageII('IF-II-L-v1.0', device=device)
if_III = StableStageIII('stable-diffusion-x4-upscaler', device=device)
t5 = T5Embedder(device="cpu")
```

### I. Dream
Dream is the text-to-image mode of the IF model

```python
from deepfloyd_if.pipelines import dream

prompt = 'ultra close-up color photo portrait of rainbow owl with deer horns in the woods'
count = 4

result = dream(
    t5=t5, if_I=if_I, if_II=if_II, if_III=if_III,
    prompt=[prompt]*count,
    seed=42,
    if_I_kwargs={
        "guidance_scale": 7.0,
        "sample_timestep_respacing": "smart100",
    },
    if_II_kwargs={
        "guidance_scale": 4.0,
        "sample_timestep_respacing": "smart50",
    },
    if_III_kwargs={
        "guidance_scale": 9.0,
        "noise_level": 20,
        "sample_timestep_respacing": "75",
    },
)

if_III.show(result['III'], size=14)
```
![](./pics/dream-III.jpg)

## II. Zero-shot Image-to-Image Translation

![](./pics/img_to_img_scheme.jpeg)

In Style Transfer mode, the output of your prompt comes out at the style of the `support_pil_img`
```python
from deepfloyd_if.pipelines import style_transfer

result = style_transfer(
    t5=t5, if_I=if_I, if_II=if_II,
    support_pil_img=raw_pil_image,
    style_prompt=[
        'in style of professional origami',
        'in style of oil art, Tate modern',
        'in style of plastic building bricks',
        'in style of classic anime from 1990',
    ],
    seed=42,
    if_I_kwargs={
        "guidance_scale": 10.0,
        "sample_timestep_respacing": "10,10,10,10,10,10,10,10,0,0",
        'support_noise_less_qsample_steps': 5,
    },
    if_II_kwargs={
        "guidance_scale": 4.0,
        "sample_timestep_respacing": 'smart50',
        "support_noise_less_qsample_steps": 5,
    },
)
if_I.show(result['II'], 1, 20)
```

![Alternative Text](./pics/deep_floyd_if_image_2_image.gif)


## III. Super Resolution
For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated by IF (two cascades):

```python
from deepfloyd_if.pipelines import super_resolution

middle_res = super_resolution(
    t5,
    if_III=if_II,
    prompt=['woman with a blue headscarf and a blue sweaterp, detailed picture, 4k dslr, best quality'],
    support_pil_img=raw_pil_image,
    img_scale=4.,
    img_size=64,
    if_III_kwargs={
        'sample_timestep_respacing': 'smart100',
        'aug_level': 0.5,
        'guidance_scale': 6.0,
    },
)
high_res = super_resolution(
    t5,
    if_III=if_III,
    prompt=[''],
    support_pil_img=middle_res['III'][0],
    img_scale=4.,
    img_size=256,
    if_III_kwargs={
        "guidance_scale": 9.0,
        "noise_level": 20,
        "sample_timestep_respacing": "75",
    },
)
show_superres(raw_pil_image, high_res['III'][0])
```

![](./pics/if_as_upscaler.jpg)


### IV. Zero-shot Inpainting

```python
from deepfloyd_if.pipelines import inpainting

result = inpainting(
    t5=t5, if_I=if_I,
    if_II=if_II,
    if_III=if_III,
    support_pil_img=raw_pil_image,
    inpainting_mask=inpainting_mask,
    prompt=[
        'oil art, a man in a hat',
    ],
    seed=42,
    if_I_kwargs={
        "guidance_scale": 7.0,
        "sample_timestep_respacing": "10,10,10,10,10,0,0,0,0,0",
        'support_noise_less_qsample_steps': 0,
    },
    if_II_kwargs={
        "guidance_scale": 4.0,
        'aug_level': 0.0,
        "sample_timestep_respacing": '100',
    },
    if_III_kwargs={
        "guidance_scale": 9.0,
        "noise_level": 20,
        "sample_timestep_respacing": "75",
    },
)
if_I.show(result['I'], 2, 3)
if_I.show(result['II'], 2, 6)
if_I.show(result['III'], 2, 14)
```
![](./pics/deep_floyd_if_inpainting.gif)

### 🤗 Model Zoo 🤗
The link to download the weights as well as the model cards will be available soon on each model of the model zoo

#### Original

| Name                                                      | Cascade | Params | FID  | Batch size | Steps |
|:----------------------------------------------------------|:-------:|:------:|:----:|:----------:|:-----:|
| [IF-I-M](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)    |    I    |  400M  | 8.86 |    3072    | 2.5M  |
| [IF-I-L](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)    |    I    |  900M  | 8.06 |    3200    | 3.0M  |
| [IF-I-XL](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)* |    I    |  4.3B  | 6.66 |    3072    | 2.42M |
| [IF-II-M](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)  |   II    |  450M  |  -   |    1536    | 2.5M  |
| [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* |   II    |  1.2B  |  -   |    1536    | 2.5M  |
| IF-III-L* _(soon)_                                        |   III   |  700M  |  -   |    3072    | 1.25M |

 *best modules

### Quantitative Evaluation

`FID = 6.66`

![](./pics/fid30k_if.jpg)

## License

The code in this repository is released under the bespoke license (see added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)).

The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE.

**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.*

## Limitations and Biases

The models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information.


## 🎓 DeepFloyd IF creators:

- Alex Shonenkov [GitHub](https://github.com/shonenkov) | [Linktr](https://linktr.ee/shonenkovAI)
- Misha Konstantinov [GitHub](https://github.com/zeroshot-ai) | [Twitter](https://twitter.com/_bra_ket)
- Daria Bakshandaeva [GitHub](https://github.com/Gugutse) | [Twitter](https://twitter.com/_gugutse_)
- Christoph Schuhmann [GitHub](https://github.com/christophschuhmann) | [Twitter](https://twitter.com/laion_ai)
- Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv)
- Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls)


## 📄 Research Paper (Soon)

## Acknowledgements

Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory consumption during inference, creating demos and giving cool advice!

## 🚀 External Contributors 🚀
- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉;
- Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving loading time of unet models by 80%;
for integration Stable-Diffusion-x4 as native pipeline 💪;
- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌;
- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀;
- Thanks, [@Dango233](https://github.com/Dango233), for adapting IF with xformers memory efficient attention 💪;


================================================
FILE: deepfloyd_if/__init__.py
================================================
# -*- coding: utf-8 -*-


__version__ = '1.0.2rc0'


================================================
FILE: deepfloyd_if/model/__init__.py
================================================
# -*- coding: utf-8 -*-
from .unet import UNetModel, SuperResUNetModel


__all__ = ['UNetModel', 'SuperResUNetModel']


================================================
FILE: deepfloyd_if/model/gaussian_diffusion.py
================================================
# -*- coding: utf-8 -*-
"""
This code started out as a PyTorch port of Ho et al's diffusion model:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import math
import numpy as np
import torch

from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == 'linear':
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == 'cosine':
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f'unknown beta schedule: {schedule_name}')


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion model.
    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, 'betas must be 1-D'
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def dynamic_thresholding(self, x, p=0.995, c=1.7):
        """
        Dynamic thresholding, a diffusion sampling technique from Imagen (https://arxiv.org/abs/2205.11487)
        to leverage high guidance weights and generating more photorealistic and detailed images
        than previously was possible based on x.clamp(-1, 1) vanilla clipping or static thresholding

        p — percentile determine relative value for clipping threshold for dynamic compression,
            helps prevent oversaturation recommend values [0.96 — 0.99]

        c — absolute hard clipping of value for clipping threshold for dynamic compression,
            helps prevent undersaturation and low contrast issues; recommend values [1.5 — 2.]
        """
        x_shapes = x.shape
        s = torch.quantile(x.abs().reshape(x_shapes[0], -1), p, dim=-1)
        s = torch.clamp(s, min=1, max=c)
        x_compressed = torch.clip(x.reshape(x_shapes[0], -1).T, -s, s) / s
        x_compressed = x_compressed.T.reshape(x_shapes)
        return x_compressed

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
        denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = torch.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = torch.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = torch.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                x = self.dynamic_thresholding(x, p=dynamic_thresholding_p, c=dynamic_thresholding_c)
                return x  # x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            'mean': model_mean,
            'variance': model_variance,
            'log_variance': model_log_variance,
            'pred_xstart': pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def p_sample(
        self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
        denoised_fn=None, model_kwargs=None, inpainting_mask=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = torch.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if inpainting_mask is None:
            inpainting_mask = torch.ones_like(x, device=x.device)

        sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise
        sample = (1 - inpainting_mask)*x + inpainting_mask*sample
        return {'sample': sample, 'pred_xstart': out['pred_xstart']}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        dynamic_thresholding_p=0.99,
        dynamic_thresholding_c=1.7,
        inpainting_mask=None,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        sample_fn=None,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for step_idx, sample in enumerate(self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            denoised_fn=denoised_fn,
            inpainting_mask=inpainting_mask,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
        )):
            if sample_fn is not None:
                sample = sample_fn(step_idx, sample)
            final = sample
        return final['sample']

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        inpainting_mask=None,
        noise=None,
        clip_denoised=True,
        dynamic_thresholding_p=0.99,
        dynamic_thresholding_c=1.7,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.
        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = torch.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    dynamic_thresholding_p=dynamic_thresholding_p,
                    dynamic_thresholding_c=dynamic_thresholding_c,
                    denoised_fn=denoised_fn,
                    inpainting_mask=inpainting_mask,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out['sample']

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        dynamic_thresholding_p=0.99,
        dynamic_thresholding_c=1.7,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out['pred_xstart'])
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = torch.randn_like(x)
        mean_pred = (
            out['pred_xstart'] * torch.sqrt(alpha_bar_prev)
            + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {'sample': sample, 'pred_xstart': out['pred_xstart']}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        dynamic_thresholding_p=0.99,
        dynamic_thresholding_c=1.7,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, 'Reverse ODE only for deterministic path'
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out['pred_xstart']
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out['pred_xstart'] * torch.sqrt(alpha_bar_next)
            + torch.sqrt(1 - alpha_bar_next) * eps
        )

        return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        dynamic_thresholding_p=0.99,
        dynamic_thresholding_c=1.7,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        sample_fn=None,
    ):
        """
        Generate samples from the model using DDIM.
        Same usage as p_sample_loop().
        """
        final = None
        for step_idx, sample in enumerate(self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        )):
            if sample_fn is not None:
                sample = sample_fn(step_idx, sample)
            final = sample
        return final['sample']

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        dynamic_thresholding_p=0.99,
        dynamic_thresholding_c=1.7,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.
        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = torch.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    dynamic_thresholding_p=dynamic_thresholding_p,
                    dynamic_thresholding_c=dynamic_thresholding_c,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out['sample']

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.
        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out['mean'], out['log_variance']
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out['mean'], log_scales=0.5 * out['log_variance']
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = torch.where((t == 0), decoder_nll, kl)
        return {'output': output, 'pred_xstart': out['pred_xstart']}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms['loss'] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )['output']
            if self.loss_type == LossType.RESCALED_KL:
                terms['loss'] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = torch.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                terms['vb'] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )['output']
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms['vb'] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms['mse'] = mean_flat((target - model_output) ** 2)
            if 'vb' in terms:
                terms['loss'] = terms['mse'] + terms['vb']
            else:
                terms['loss'] = terms['mse']
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = torch.tensor([t] * batch_size, device=device)
            noise = torch.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with torch.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out['output'])
            xstart_mse.append(mean_flat((out['pred_xstart'] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out['pred_xstart'])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = torch.stack(vb, dim=1)
        xstart_mse = torch.stack(xstart_mse, dim=1)
        mse = torch.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            'total_bpd': total_bpd,
            'prior_bpd': prior_bpd,
            'vb': vb,
            'xstart_mse': xstart_mse,
            'mse': mse,
        }


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


================================================
FILE: deepfloyd_if/model/losses.py
================================================
# -*- coding: utf-8 -*-
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion model codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""

import torch
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, 'at least one argument must be a Tensor'

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = torch.where(
        x < -0.999,
        log_cdf_plus,
        torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs


================================================
FILE: deepfloyd_if/model/nn.py
================================================
# -*- coding: utf-8 -*-
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def gelu(x):
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


@torch.jit.script
def gelu_jit(x):
    """OpenAI's gelu implementation."""
    return gelu(x)


class GELUJit(torch.nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return gelu_jit(input)


def get_activation(activation):
    if activation == 'silu':
        return torch.nn.SiLU()
    elif activation == 'gelu_jit':
        return GELUJit()
    elif activation == 'gelu':
        return torch.nn.GELU()
    elif activation == 'none':
        return torch.nn.Identity()
    else:
        raise ValueError(f'unknown activation type {activation}')


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)

    def forward(self, x):
        y = super().forward(x).to(x.dtype)
        return y


class AttentionPooling(nn.Module):

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs*self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, class_token_length, dim_per_head)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, length+class_token_length, dim_per_head)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum(
            'bct,bcs->bts', q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum('bts,bcs->bct', weight, v)

        # (bs, length+1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f'unsupported dimensions: {dims}')


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f'unsupported dimensions: {dims}')


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def normalization(channels, dtype=None):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)


def timestep_embedding(timesteps, dim, max_period=10000, dtype=None):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if dtype is None:
        dtype = torch.float32
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device, dtype=dtype)
    args = timesteps[:, None].type(dtype) * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def attention(q, k, v, d_k):
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    scores = F.softmax(scores, dim=-1)
    output = torch.matmul(scores, v)
    return output


================================================
FILE: deepfloyd_if/model/resample.py
================================================
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod

import torch
import numpy as np


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.
    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.
        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = torch.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = torch.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, num_timesteps):
        self._weights = np.ones([num_timesteps])

    def weights(self):
        return self._weights


class StaticSampler(ABC):

    def sample(self, batch_size, device, static_step=100):
        indices_np = np.ones(batch_size, dtype=np.int) * static_step
        weights_np = np.ones(batch_size, dtype=np.int)
        indices = torch.from_numpy(indices_np).long().to(device)
        weights = torch.from_numpy(weights_np).float().to(device)
        return indices, weights


================================================
FILE: deepfloyd_if/model/respace.py
================================================
# -*- coding: utf-8 -*-
import torch
import numpy as np

from . import gaussian_diffusion as gd


def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule='linear',
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing='',
):
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if not timestep_respacing:
        timestep_respacing = [steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith('ddim'):
            desired_count = int(section_counts[len('ddim'):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f'cannot create exactly {num_timesteps} steps with an integer stride'
            )
        elif section_counts == 'fast27':
            steps = space_timesteps(num_timesteps, '10,10,3,2,2')
            # Help reduce DDIM artifacts from noisiest timesteps.
            steps.remove(num_timesteps - 1)
            steps.add(num_timesteps - 3)
            return steps
        section_counts = [int(x) for x in section_counts.split(',')]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f'cannot divide section of {size} steps into {section_count}'
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(gd.GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs['betas'])

        base_diffusion = gd.GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs['betas'] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)


================================================
FILE: deepfloyd_if/model/unet.py
================================================
# -*- coding: utf-8 -*-
import os
import math
from abc import abstractmethod

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \
    AttentionPooling

_FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0))
print('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION')
if _FORCE_MEM_EFFICIENT_ATTN:
    from xformers.ops import memory_efficient_attention  # noqa


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, encoder_out=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, AttentionBlock):
                x = layer(x, encoder_out)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining a convolution is applied.
    :param dims: determines the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.dtype = dtype
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
        else:
            if self.dtype == torch.bfloat16:
                x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16)
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            if self.dtype == torch.bfloat16:
                x = x.type(torch.bfloat16)
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining a convolution is applied.
    :param dims: determines the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.dtype = dtype
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: specified, the number of out channels.
    :param use_conv: True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines the signal is 1D, 2D, or 3D.
    :param up: True, use this block for upsampling.
    :param down: True, use this block for downsampling.
    """

    def __init__(
            self,
            channels,
            emb_channels,
            dropout,
            activation,
            out_channels=None,
            use_conv=False,
            use_scale_shift_norm=False,
            dims=2,
            up=False,
            down=False,
            dtype=None,
            efficient_activation=False,
            scale_skip_connection=False,
    ):
        super().__init__()
        self.dtype = dtype
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.efficient_activation = efficient_activation
        self.scale_skip_connection = scale_skip_connection

        self.in_layers = nn.Sequential(
            normalization(channels, dtype=self.dtype),
            get_activation(activation),
            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims, dtype=self.dtype)
            self.x_upd = Upsample(channels, False, dims, dtype=self.dtype)
        elif down:
            self.h_upd = Downsample(channels, False, dims, dtype=self.dtype)
            self.x_upd = Downsample(channels, False, dims, dtype=self.dtype)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.Identity() if self.efficient_activation else get_activation(activation),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
                dtype=self.dtype
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels, dtype=self.dtype),
            get_activation(activation),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)

        res = self.skip_connection(x) + h
        if self.scale_skip_connection:
            res *= 0.7071  # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf
        return res


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
            self,
            channels,
            num_heads=1,
            num_head_channels=-1,
            disable_self_attention=False,
            encoder_channels=None,
            dtype=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.channels = channels
        self.disable_self_attention = disable_self_attention
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels, dtype=self.dtype)
        self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)
        if self.disable_self_attention:
            self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype)
        else:
            self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)
        self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention)

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype)
            self.norm_encoder = normalization(encoder_channels, dtype=self.dtype)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype))

    def forward(self, x, encoder_out=None):
        b, c, *spatial = x.shape
        qkv = self.qkv(self.norm(x).view(b, c, -1))
        if encoder_out is not None:
            # from imagen article: https://arxiv.org/pdf/2205.11487.abs
            encoder_out = self.norm_encoder(encoder_out)
            # # #
            encoder_out = self.encoder_kv(encoder_out)
            h = self.attention(qkv, encoder_out)
        else:
            h = self.attention(qkv)
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads, disable_self_attention=False):
        super().__init__()
        self.n_heads = n_heads
        self.disable_self_attention = disable_self_attention

    def forward(self, qkv, encoder_kv=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        if self.disable_self_attention:
            ch = width // (1 * self.n_heads)
            q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1)
        else:
            assert width % (3 * self.n_heads) == 0
            ch = width // (3 * self.n_heads)
            q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            if self.disable_self_attention:
                k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            else:
                ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
                k = torch.cat([ek, k], dim=-1)
                v = torch.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        if _FORCE_MEM_EFFICIENT_ATTN:
            q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
            a = memory_efficient_attention(q, k, v)
            a = a.permute(0, 2, 1)
        else:
            weight = torch.einsum(
                'bct,bcs->bts', q * scale, k * scale
            )  # More stable with f16 than dividing afterwards
            weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
            a = torch.einsum('bts,bcs->bct', weight, v)
        return a.reshape(bs, -1, length)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines the signal is 1D, 2D, or 3D.
    :param num_classes: specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    """

    def __init__(
            self,
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            activation,
            encoder_dim,
            att_pool_heads,
            encoder_channels,
            image_size,
            disable_self_attentions=None,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            num_classes=None,
            precision='32',
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            efficient_activation=False,
            scale_skip_connection=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.encoder_channels = encoder_channels
        self.encoder_dim = encoder_dim
        self.efficient_activation = efficient_activation
        self.scale_skip_connection = scale_skip_connection
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.dropout = dropout

        # adapt attention resolutions
        if isinstance(attention_resolutions, str):
            self.attention_resolutions = []
            for res in attention_resolutions.split(','):
                self.attention_resolutions.append(image_size // int(res))
        else:
            self.attention_resolutions = attention_resolutions
        self.attention_resolutions = tuple(self.attention_resolutions)
        #

        # adapt disable self attention resolutions
        if not disable_self_attentions:
            self.disable_self_attentions = []
        elif disable_self_attentions is True:
            self.disable_self_attentions = attention_resolutions
        elif isinstance(disable_self_attentions, str):
            self.disable_self_attentions = []
            for res in disable_self_attentions.split(','):
                self.disable_self_attentions.append(image_size // int(res))
        else:
            self.disable_self_attentions = disable_self_attentions
        self.disable_self_attentions = tuple(self.disable_self_attentions)
        #

        # adapt channel mult
        if isinstance(channel_mult, str):
            self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(','))
        else:
            self.channel_mult = tuple(channel_mult)
        #

        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.dtype = torch.float32

        self.precision = str(precision)
        self.use_fp16 = precision == '16'
        if self.precision == '16':
            self.dtype = torch.float16
        elif self.precision == 'bf16':
            self.dtype = torch.bfloat16

        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        self.time_embed_dim = model_channels * max(self.channel_mult)
        self.time_embed = nn.Sequential(
            linear(model_channels, self.time_embed_dim, dtype=self.dtype),
            get_activation(activation),
            linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, self.time_embed_dim)

        ch = input_ch = int(self.channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1

        if isinstance(num_res_blocks, int):
            num_res_blocks = [num_res_blocks]*len(self.channel_mult)
        self.num_res_blocks = num_res_blocks

        for level, mult in enumerate(self.channel_mult):
            for _ in range(num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        self.time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        activation=activation,
                        efficient_activation=self.efficient_activation,
                        scale_skip_connection=self.scale_skip_connection,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in self.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            encoder_channels=encoder_channels,
                            dtype=self.dtype,
                            disable_self_attention=ds in self.disable_self_attentions,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(self.channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            self.time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            activation=activation,
                            efficient_activation=self.efficient_activation,
                            scale_skip_connection=self.scale_skip_connection,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                self.time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                activation=activation,
                efficient_activation=self.efficient_activation,
                scale_skip_connection=self.scale_skip_connection,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                encoder_channels=encoder_channels,
                dtype=self.dtype,
                disable_self_attention=ds in self.disable_self_attentions,
            ),
            ResBlock(
                ch,
                self.time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                activation=activation,
                efficient_activation=self.efficient_activation,
                scale_skip_connection=self.scale_skip_connection,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(self.channel_mult))[::-1]:
            for i in range(num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        self.time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        activation=activation,
                        efficient_activation=self.efficient_activation,
                        scale_skip_connection=self.scale_skip_connection,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in self.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            encoder_channels=encoder_channels,
                            dtype=self.dtype,
                            disable_self_attention=ds in self.disable_self_attentions,
                        )
                    )
                if level and i == num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            self.time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                            dtype=self.dtype,
                            activation=activation,
                            efficient_activation=self.efficient_activation,
                            scale_skip_connection=self.scale_skip_connection,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch, dtype=self.dtype),
            get_activation(activation),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)),
        )

        self.activation_layer = get_activation(activation) if self.efficient_activation else nn.Identity()

        self.encoder_pooling = nn.Sequential(
            nn.LayerNorm(encoder_dim, dtype=self.dtype),
            AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype),
            nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype),
            nn.LayerNorm(self.time_embed_dim, dtype=self.dtype)
        )

        if encoder_dim != encoder_channels:
            self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype)
        else:
            self.encoder_proj = nn.Identity()

        self.cache = None

    def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, dtype=self.dtype))

        if use_cache and self.cache is not None:
            encoder_out, encoder_pool = self.cache
        else:
            text_emb = text_emb.type(self.dtype)
            encoder_out = self.encoder_proj(text_emb)
            encoder_out = encoder_out.permute(0, 2, 1)  # NLC -> NCL
            if timestep_text_emb is None:
                timestep_text_emb = text_emb
            encoder_pool = self.encoder_pooling(timestep_text_emb)
            if use_cache:
                self.cache = (encoder_out, encoder_pool)

        emb = emb + encoder_pool.to(emb)

        if aug_emb is not None:
            emb = emb + aug_emb.to(emb)

        emb = self.activation_layer(emb)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, encoder_out)
            hs.append(h)
        h = self.middle_block(h, emb, encoder_out)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, encoder_out)
        h = h.type(self.dtype)
        h = self.out(h)
        return h


class SuperResUNetModel(UNetModel):
    """
    A text2im model that performs super-resolution.
    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs):
        self.low_res_diffusion = low_res_diffusion
        self.interpolate_mode = interpolate_mode
        super().__init__(*args, **kwargs)

        self.aug_proj = nn.Sequential(
            linear(self.model_channels, self.time_embed_dim, dtype=self.dtype),
            get_activation(kwargs['activation']),
            linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
        )

    def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
        bs, _, new_height, new_width = x.shape

        align_corners = True
        if self.interpolate_mode == 'nearest':
            align_corners = None

        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners
        )

        if aug_level is None:
            aug_steps = (np.random.random(bs)*1000).astype(np.int64)  # uniform [0, 1)
            aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long)
        else:
            aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long)

        upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps)
        x = torch.cat([x, upsampled], dim=1)

        aug_emb = self.aug_proj(
            timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype)
        )
        return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs)


================================================
FILE: deepfloyd_if/modules/__init__.py
================================================
# -*- coding: utf-8 -*-
from .stage_I import IFStageI
from .stage_II import IFStageII
from .stage_III import IFStageIII
from .stage_III_sd_x4 import StableStageIII
from .t5 import T5Embedder
from .base import IFBaseModule

__all__ = ['IFBaseModule', 'IFStageI', 'IFStageII', 'IFStageIII', 'StableStageIII', 'T5Embedder']


================================================
FILE: deepfloyd_if/modules/base.py
================================================
# -*- coding: utf-8 -*-
import os
import random
import platform
from datetime import datetime

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device


from .. import utils
from ..model.respace import create_gaussian_diffusion
from .utils import load_model_weights, predict_proba, clip_process_generations


class IFBaseModule:

    stage = '-'

    available_models = []
    cpu_zero_emb = np.load(os.path.join(utils.RESOURCES_ROOT, 'zero_t5-v1_1-xxl_vector.npy'))
    cpu_zero_emb = torch.from_numpy(cpu_zero_emb)

    respacing_modes = {
        'fast27': '10,10,3,2,2',
        'smart27': '7,4,2,1,2,4,7',
        'smart50': '10,6,4,3,2,2,3,4,6,10',
        'smart100': '1,1,1,1,2,2,2,2,2,2,3,3,4,4,5,5,6,7,7,8,9,10,13',
        'smart185': '1,1,2,2,2,3,3,3,4,5,6,7,8,9,10,11,12,13,14,15,16,18,20',
        'super27': '1,1,1,1,1,1,1,2,5,13',  # for III super-res
        'super40': '2,2,2,2,2,2,3,4,6,15',  # for III super-res
        'super100': '4,4,6,6,8,8,10,10,14,30',  # for III super-res
    }

    wm_pil_img = Image.open(os.path.join(utils.RESOURCES_ROOT, 'wm.png'))

    try:
        import clip  # noqa
    except ModuleNotFoundError:
        print('Warning! You should install CLIP: "pip install git+https://github.com/openai/CLIP.git --no-deps"')
        raise

    clip_model, clip_preprocess = clip.load('ViT-L/14', device='cpu')
    clip_model.eval()

    cpu_w_weights, cpu_w_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'w_head_v1.npz'))
    cpu_p_weights, cpu_p_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'p_head_v1.npz'))
    w_threshold, p_threshold = 0.5, 0.5

    def __init__(self, dir_or_name, device, pil_img_size=256, cache_dir=None, hf_token=None):
        self.hf_token = hf_token
        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
        self.dir_or_name = dir_or_name
        self.conf = self.load_conf(dir_or_name) if not self.use_diffusers else None
        self.device = torch.device(device)
        self.zero_emb = self.cpu_zero_emb.clone().to(self.device)
        self.pil_img_size = pil_img_size

    @property
    def use_diffusers(self):
        return False

    def embeddings_to_image(
        self, t5_embs, low_res=None, *,
        style_t5_embs=None,
        positive_t5_embs=None,
        negative_t5_embs=None,
        batch_repeat=1,
        dynamic_thresholding_p=0.95,
        sample_loop='ddpm',
        sample_timestep_respacing='smart185',
        dynamic_thresholding_c=1.5,
        guidance_scale=7.0,
        aug_level=0.25,
        positive_mixer=0.15,
        blur_sigma=None,
        img_size=None,
        img_scale=4.0,
        aspect_ratio='1:1',
        progress=True,
        seed=None,
        sample_fn=None,
        support_noise=None,
        support_noise_less_qsample_steps=0,
        inpainting_mask=None,
        **kwargs,
    ):
        self._clear_cache()
        image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale)
        diffusion = self.get_diffusion(sample_timestep_respacing)

        bs_scale = 2 if positive_t5_embs is None else 3

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // bs_scale]
            combined = torch.cat([half]*bs_scale, dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            if bs_scale == 3:
                cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
                half_eps = uncond_eps + guidance_scale * (
                    cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps)
                pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps)
                eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0)
            else:
                cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        seed = self.seed_everything(seed)

        text_emb = t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)
        batch_size = text_emb.shape[0] * batch_repeat

        if positive_t5_embs is not None:
            positive_t5_embs = positive_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)

        if negative_t5_embs is not None:
            negative_t5_embs = negative_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)

        timestep_text_emb = None
        if style_t5_embs is not None:
            list_timestep_text_emb = [
                style_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1),
            ]
            if positive_t5_embs is not None:
                list_timestep_text_emb.append(positive_t5_embs)
            if negative_t5_embs is not None:
                list_timestep_text_emb.append(negative_t5_embs)
            else:
                list_timestep_text_emb.append(
                    self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))
            timestep_text_emb = torch.cat(list_timestep_text_emb, dim=0).to(self.device, dtype=self.model.dtype)

        metadata = {
            'seed': seed,
            'guidance_scale': guidance_scale,
            'dynamic_thresholding_p': dynamic_thresholding_p,
            'dynamic_thresholding_c': dynamic_thresholding_c,
            'batch_size': batch_size,
            'device_name': self.device_name,
            'img_size': [image_w, image_h],
            'sample_loop': sample_loop,
            'sample_timestep_respacing': sample_timestep_respacing,
            'stage': self.stage,
        }

        list_text_emb = [t5_embs.to(self.device)]
        if positive_t5_embs is not None:
            list_text_emb.append(positive_t5_embs.to(self.device))
        if negative_t5_embs is not None:
            list_text_emb.append(negative_t5_embs.to(self.device))
        else:
            list_text_emb.append(
                self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))

        model_kwargs = dict(
            text_emb=torch.cat(list_text_emb, dim=0).to(self.device, dtype=self.model.dtype),
            timestep_text_emb=timestep_text_emb,
            use_cache=True,
        )
        if low_res is not None:
            if blur_sigma is not None:
                low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res)
            model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device)
            model_kwargs['aug_level'] = aug_level

        if support_noise is None:
            noise = torch.randn(
                (batch_size * bs_scale, 3, image_h, image_w), device=self.device, dtype=self.model.dtype)
        else:
            assert support_noise_less_qsample_steps < len(diffusion.timestep_map) - 1
            assert support_noise.shape == (1, 3, image_h, image_w)
            q_sample_steps = torch.tensor([int(len(diffusion.timestep_map) - 1 - support_noise_less_qsample_steps)])
            support_noise = support_noise.cpu()
            noise = support_noise.clone()
            noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...] = diffusion.q_sample(
                support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...],
                q_sample_steps,
            )
            noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype)

        if inpainting_mask is not None:
            inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long)

        if sample_loop == 'ddpm':
            with torch.no_grad():
                sample = diffusion.p_sample_loop(
                    model_fn,
                    (batch_size * bs_scale, 3, image_h, image_w),
                    noise=noise,
                    clip_denoised=True,
                    model_kwargs=model_kwargs,
                    dynamic_thresholding_p=dynamic_thresholding_p,
                    dynamic_thresholding_c=dynamic_thresholding_c,
                    inpainting_mask=inpainting_mask,
                    device=self.device,
                    progress=progress,
                    sample_fn=sample_fn,
                )[:batch_size]
        elif sample_loop == 'ddim':
            with torch.no_grad():
                sample = diffusion.ddim_sample_loop(
                    model_fn,
                    (batch_size * bs_scale, 3, image_h, image_w),
                    noise=noise,
                    clip_denoised=True,
                    model_kwargs=model_kwargs,
                    dynamic_thresholding_p=dynamic_thresholding_p,
                    dynamic_thresholding_c=dynamic_thresholding_c,
                    device=self.device,
                    progress=progress,
                    sample_fn=sample_fn,
                )[:batch_size]
        else:
            raise ValueError(f'Sample loop "{sample_loop}" doesnt support')

        sample = self.__validate_generations(sample)
        self._clear_cache()

        return sample, metadata

    def load_conf(self, dir_or_name, filename='config.yml'):
        path = self._get_path_or_download_file_from_hf(dir_or_name, filename)
        conf = OmegaConf.load(path)
        return conf

    def load_checkpoint(self, model, dir_or_name, filename='pytorch_model.bin'):
        path = self._get_path_or_download_file_from_hf(dir_or_name, filename)
        if os.path.exists(path):
            checkpoint = torch.load(path, map_location='cpu')
            param_device = 'cpu'
            for param_name, param in checkpoint.items():
                set_module_tensor_to_device(model, param_name, param_device, value=param)
        else:
            print(f'Warning! In directory "{dir_or_name}" filename "pytorch_model.bin" is not found.')
        return model

    def _get_path_or_download_file_from_hf(self, dir_or_name, filename):
        if dir_or_name in self.available_models:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
                            force_filename=filename, token=self.hf_token)
            return os.path.join(cache_dir, filename)
        else:
            return os.path.join(dir_or_name, filename)

    def get_diffusion(self, timestep_respacing):
        timestep_respacing = self.respacing_modes.get(timestep_respacing, timestep_respacing)
        diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            sigma_small=False,
            noise_schedule='cosine',
            use_kl=False,
            predict_xstart=False,
            rescale_timesteps=True,
            rescale_learned_sigmas=True,
            timestep_respacing=timestep_respacing,
        )
        return diffusion

    @staticmethod
    def seed_everything(seed=None):
        if seed is None:
            seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True
        return seed

    def device_name(self):
        if self.device.type == 'cpu':
            return 'cpu_' + str(platform.processor())
        if self.device.type == 'cuda':
            return torch.cuda.get_device_name(self.device)
        return '-'

    def to_images(self, generations, disable_watermark=False):
        bs, c, h, w = generations.shape
        coef = min(h / self.pil_img_size, w / self.pil_img_size)
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        S1, S2 = 1024 ** 2, img_w * img_h
        K = (S2 / S1) ** 0.5
        wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)

        wm_img = self.wm_pil_img.resize(
            (wm_size, wm_size), getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None)

        pil_images = []
        for image in ((generations + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu():
            pil_img = torchvision.transforms.functional.to_pil_image(image).convert('RGB')
            pil_img = pil_img.resize((img_w, img_h), getattr(Image, 'Resampling', Image).NEAREST)
            if not disable_watermark:
                pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])
            pil_images.append(pil_img)
        return pil_images

    def show(self, pil_images, nrow=None, size=10):
        if nrow is None:
            nrow = round(len(pil_images)**0.5)

        imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
        if not isinstance(imgs, list):
            imgs = [imgs.cpu()]

        fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size))
        for i, img in enumerate(imgs):
            img = img.detach()
            img = torchvision.transforms.functional.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

        fix.show()
        plt.show()

    def _clear_cache(self):
        self.model.cache = None

    def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale):
        if low_res is not None:
            bs, c, h, w = low_res.shape
            image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale//32))*32
        else:
            scale_w, scale_h = aspect_ratio.split(':')
            scale_w, scale_h = int(scale_w), int(scale_h)
            coef = scale_w / scale_h
            image_h, image_w = img_size, img_size
            if coef >= 1:
                image_w = int(round(img_size/8 * coef) * 8)
            else:
                image_h = int(round(img_size/8 / coef) * 8)

        assert image_h % 8 == 0
        assert image_w % 8 == 0

        return image_w, image_h

    def __validate_generations(self, generations):
        with torch.no_grad():
            imgs = clip_process_generations(generations)
            image_features = self.clip_model.encode_image(imgs.to('cpu'))
            image_features = image_features.detach().cpu().numpy().astype(np.float16)
            p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
            w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
            query = p_pred > self.p_threshold
            if query.sum() > 0:
                generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])
            query = w_pred > self.w_threshold
            if query.sum() > 0:
                generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])
        return generations


================================================
FILE: deepfloyd_if/modules/stage_I.py
================================================
# -*- coding: utf-8 -*-
import accelerate

from .base import IFBaseModule
from ..model import UNetModel


class IFStageI(IFBaseModule):
    stage = 'I'
    available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0']

    def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
        """
        :param conf_or_path:
        :param device:
        :param cache_dir:
        :param use_auth_token:
        """
        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
        model_params = dict(self.conf.params)
        model_params.update(model_kwargs or {})
        with accelerate.init_empty_weights():
            self.model = UNetModel(**model_params)
        self.model = self.load_checkpoint(self.model, self.dir_or_name)
        self.model.eval().to(self.device)

    def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
                            batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,
                            sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,
                            aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):

        return super().embeddings_to_image(
            t5_embs=t5_embs,
            style_t5_embs=style_t5_embs,
            positive_t5_embs=positive_t5_embs,
            negative_t5_embs=negative_t5_embs,
            batch_repeat=batch_repeat,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            sample_loop=sample_loop,
            sample_timestep_respacing=sample_timestep_respacing,
            guidance_scale=guidance_scale,
            img_size=64,
            aspect_ratio=aspect_ratio,
            progress=progress,
            seed=seed,
            sample_fn=sample_fn,
            positive_mixer=positive_mixer,
            **kwargs
        )


================================================
FILE: deepfloyd_if/modules/stage_II.py
================================================
# -*- coding: utf-8 -*-
import accelerate

from .base import IFBaseModule
from ..model import SuperResUNetModel


class IFStageII(IFBaseModule):
    stage = 'II'
    available_models = ['IF-II-M-v1.0', 'IF-II-L-v1.0']

    def __init__(self, *args, model_kwargs=None, pil_img_size=256, **kwargs):
        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
        model_params = dict(self.conf.params)
        model_params.update(model_kwargs or {})
        with accelerate.init_empty_weights():
            self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params)
        self.model = self.load_checkpoint(self.model, self.dir_or_name)
        self.model.eval().to(self.device)

    def embeddings_to_image(
            self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
            aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm',
            sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5,
            progress=True, seed=None, sample_fn=None, **kwargs):
        return super().embeddings_to_image(
            t5_embs=t5_embs,
            low_res=low_res,
            style_t5_embs=style_t5_embs,
            positive_t5_embs=positive_t5_embs,
            negative_t5_embs=negative_t5_embs,
            batch_repeat=batch_repeat,
            aug_level=aug_level,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            sample_loop=sample_loop,
            sample_timestep_respacing=sample_timestep_respacing,
            guidance_scale=guidance_scale,
            positive_mixer=positive_mixer,
            img_size=256,
            img_scale=img_scale,
            progress=progress,
            seed=seed,
            sample_fn=sample_fn,
            **kwargs
        )


================================================
FILE: deepfloyd_if/modules/stage_III.py
================================================
# -*- coding: utf-8 -*-
import accelerate

from .base import IFBaseModule
from ..model import SuperResUNetModel


class IFStageIII(IFBaseModule):

    available_models = ['IF-III-L-v1.0']

    def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
        model_params = dict(self.conf.params)
        model_params.update(model_kwargs or {})
        with accelerate.init_empty_weights():
            self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params)
        self.model = self.load_checkpoint(self.model, self.dir_or_name)
        self.model.eval().to(self.device)

    def embeddings_to_image(
            self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
            aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,
            sample_loop='ddpm', sample_timestep_respacing='super40', guidance_scale=4.0, img_scale=4.0,
            progress=True, seed=None, sample_fn=None, **kwargs):
        return super().embeddings_to_image(
            t5_embs=t5_embs,
            low_res=low_res,
            style_t5_embs=style_t5_embs,
            positive_t5_embs=positive_t5_embs,
            negative_t5_embs=negative_t5_embs,
            batch_repeat=batch_repeat,
            aug_level=aug_level,
            blur_sigma=blur_sigma,
            dynamic_thresholding_p=dynamic_thresholding_p,
            dynamic_thresholding_c=dynamic_thresholding_c,
            sample_loop=sample_loop,
            sample_timestep_respacing=sample_timestep_respacing,
            guidance_scale=guidance_scale,
            positive_mixer=positive_mixer,
            img_size=1024,
            img_scale=img_scale,
            progress=progress,
            seed=seed,
            sample_fn=sample_fn,
            **kwargs
        )


================================================
FILE: deepfloyd_if/modules/stage_III_sd_x4.py
================================================
# -*- coding: utf-8 -*-
import diffusers
from diffusers import DiffusionPipeline, DDPMScheduler
import torch
import os

from .base import IFBaseModule
import packaging.version as pv


class StableStageIII(IFBaseModule):

    available_models = ['stable-diffusion-x4-upscaler']

    def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
        if pv.parse(diffusers.__version__) <= pv.parse('0.15.1'):
            raise ValueError(
                'Make sure to have `diffusers >= 0.16.0` installed.'
                ' Please run `pip install diffusers --upgrade`'
            )

        model_id = os.path.join('stabilityai', self.dir_or_name)

        model_kwargs = model_kwargs or {}
        precision = str(model_kwargs.get('precision', '16'))
        if precision == '16':
            torch_dtype = torch.float16
        elif precision == 'bf16':
            torch_dtype = torch.bfloat16
        else:
            torch_dtype = torch.float32

        self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token)
        self.model.to(self.device)

        if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')):
            self.model.enable_xformers_memory_efficient_attention()

    @property
    def use_diffusers(self):
        if self.dir_or_name == self.available_models[-1]:
            return True
        elif os.path.isdir(self.dir_or_name) and os.path.isfile(os.path.join(self.dir_or_name, 'model_index.json')):
            return True
        return False

    def embeddings_to_image(
            self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
            aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,
            sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0,
            progress=True, seed=None, sample_fn=None, **kwargs):

        prompt = kwargs.pop('prompt')
        noise_level = kwargs.pop('noise_level', 20)

        if sample_loop == 'ddpm':
            self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config)
        else:
            raise ValueError(f"For now only the 'ddpm' sample loop type is supported, but you passed {sample_loop}")

        num_inference_steps = int(sample_timestep_respacing)

        self.model.set_progress_bar_config(disable=not progress)

        generator = torch.manual_seed(seed)
        prompt = sum([batch_repeat * [p] for p in prompt], [])
        low_res = low_res.repeat(batch_repeat, 1, 1, 1)

        metadata = {
            'image': low_res,
            'prompt': prompt,
            'noise_level': noise_level,
            'generator': generator,
            'guidance_scale': guidance_scale,
            'num_inference_steps': num_inference_steps,
            'output_type': 'pt',
        }

        images = self.model(**metadata).images

        sample = self._IFBaseModule__validate_generations(images)

        return sample, metadata


================================================
FILE: deepfloyd_if/modules/t5.py
================================================
# -*- coding: utf-8 -*-
import os
import re
import html
import urllib.parse as ul

import ftfy
import torch
from bs4 import BeautifulSoup
from transformers import T5EncoderModel, AutoTokenizer
from huggingface_hub import hf_hub_download


class T5Embedder:

    available_models = ['t5-v1_1-xxl']
    bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}')  # noqa

    def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True,
                 t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        if t5_model_kwargs is None:
            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
            if use_offload_folder is not None:
                t5_model_kwargs['offload_folder'] = use_offload_folder
                t5_model_kwargs['device_map'] = {
                    'shared': self.device,
                    'encoder.embed_tokens': self.device,
                    'encoder.block.0': self.device,
                    'encoder.block.1': self.device,
                    'encoder.block.2': self.device,
                    'encoder.block.3': self.device,
                    'encoder.block.4': self.device,
                    'encoder.block.5': self.device,
                    'encoder.block.6': self.device,
                    'encoder.block.7': self.device,
                    'encoder.block.8': self.device,
                    'encoder.block.9': self.device,
                    'encoder.block.10': self.device,
                    'encoder.block.11': self.device,
                    'encoder.block.12': 'disk',
                    'encoder.block.13': 'disk',
                    'encoder.block.14': 'disk',
                    'encoder.block.15': 'disk',
                    'encoder.block.16': 'disk',
                    'encoder.block.17': 'disk',
                    'encoder.block.18': 'disk',
                    'encoder.block.19': 'disk',
                    'encoder.block.20': 'disk',
                    'encoder.block.21': 'disk',
                    'encoder.block.22': 'disk',
                    'encoder.block.23': 'disk',
                    'encoder.final_layer_norm': 'disk',
                    'encoder.dropout': 'disk',
                }
            else:
                t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token
        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
        self.dir_or_name = dir_or_name

        tokenizer_path, path = dir_or_name, dir_or_name
        if dir_or_name in self.available_models:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
                'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
            ]:
                hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path, path = cache_dir, cache_dir
        else:
            cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
            ]:
                hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path = cache_dir

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()

    def get_text_embeddings(self, texts):
        texts = [self.text_preprocessing(text) for text in texts]

        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=77,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
        text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']

        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=text_tokens_and_mask['input_ids'].to(self.device),
                attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
            )['last_hidden_state'].detach()

        return text_encoder_embs

    def text_preprocessing(self, text):
        if self.use_text_preprocessing:
            # The exact text cleaning as was in the training stage:
            text = self.clean_caption(text)
            text = self.clean_caption(text)
            return text
        else:
            return text.lower().strip()

    @staticmethod
    def basic_clean(text):
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()

    def clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub('<person>', 'person', caption)
        # urls:
        caption = re.sub(
            r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        caption = re.sub(
            r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features='html.parser').text

        # @<nickname>
        caption = re.sub(r'@[\w\d]+\b', '', caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
        caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
        caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
        caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
        caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
        caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
        caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
        #######################################################

        # все виды тире / all types of dash --> "-"
        caption = re.sub(
            r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
            '-', caption)

        # кавычки к одному стандарту
        caption = re.sub(r'[`´«»“”¨]', '"', caption)
        caption = re.sub(r'[‘’]', "'", caption)

        # &quot;
        caption = re.sub(r'&quot;?', '', caption)
        # &amp
        caption = re.sub(r'&amp', '', caption)

        # ip adresses:
        caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

        # article ids:
        caption = re.sub(r'\d:\d\d\s+$', '', caption)

        # \n
        caption = re.sub(r'\\n', ' ', caption)

        # "#123"
        caption = re.sub(r'#\d{1,3}\b', '', caption)
        # "#12345.."
        caption = re.sub(r'#\d{5,}\b', '', caption)
        # "123456.."
        caption = re.sub(r'\b\d{6,}\b', '', caption)
        # filenames:
        caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

        #
        caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r'[\.]{2,}', r' ', caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r'(?:\-|\_)')
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, ' ', caption)

        caption = self.basic_clean(caption)

        caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
        caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
        caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

        caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
        caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
        caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
        caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
        caption = re.sub(r'\bpage\s+\d+\b', '', caption)

        caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

        caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

        caption = re.sub(r'\b\s+\:\s+', r': ', caption)
        caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption)

        caption.strip()

        caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
        caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
        caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
        caption = re.sub(r'^\.\S+$', '', caption)

        return caption.strip()


================================================
FILE: deepfloyd_if/modules/utils.py
================================================
# -*- coding: utf-8 -*-
import numpy as np
import torchvision.transforms as T


def predict_proba(X, weights, biases):
    logits = X @ weights.T + biases
    proba = np.where(logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)))
    return proba.T


def load_model_weights(path):
    model_weights = np.load(path)
    return model_weights['weights'], model_weights['biases']


def clip_process_generations(generations):
    min_size = min(generations.shape[-2:])
    return T.Compose([
        T.CenterCrop(min_size),
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
        T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])(generations)


================================================
FILE: deepfloyd_if/pipelines/__init__.py
================================================
# -*- coding: utf-8 -*-
from .dream import dream
from .style_transfer import style_transfer
from .super_resolution import super_resolution
from .inpainting import inpainting

__all__ = ['dream', 'style_transfer', 'super_resolution', 'inpainting']


================================================
FILE: deepfloyd_if/pipelines/dream.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime

import torch


def dream(
    t5,
    if_I,
    if_II=None,
    if_III=None,
    *,
    prompt,
    style_prompt=None,
    negative_prompt=None,
    seed=None,
    aspect_ratio='1:1',
    if_I_kwargs=None,
    if_II_kwargs=None,
    if_III_kwargs=None,
    progress=True,
    return_tensors=False,
    disable_watermark=False,
):
    """
    Generate pictures using text description!

    :param optional dict if_I_kwargs:
        "dynamic_thresholding_p": 0.95, [0.5, 1.0] it controls color saturation on high cfg values
        "dynamic_thresholding_c": 1.5, [1.0, 15.0] clips the limiter to avoid greyish images on high limiter values
        "guidance_scale": 7.0, [1.0, 20.0] control the level of text understanding
        "positive_mixer": 0.25, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum
        "sample_timestep_respacing": "150", see available modes IFBaseModule.respacing_modes or use custom

    :param optional dict if_II_kwargs:
        "dynamic_thresholding_p": 0.95, [0.5, 1.0] it controls color saturation on high cfg values
        "dynamic_thresholding_c": 1.0, [1.0, 15.0] clips the limiter to avoid greyish images on high limiter values
        "guidance_scale": 4.0, [1.0, 20.0] control the amount of texture and details in the final image
        "aug_level": 0.25, [0.0, 1.0] adds additional augmentation to generate more realistic images
        "positive_mixer": 0.5, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum
        "sample_timestep_respacing": "smart50", see available modes IFBaseModule.respacing_modes or use custom

    :param deepfloyd_if.modules.IFStageI if_I: obj
    :param deepfloyd_if.modules.IFStageII if_II: obj
    :param deepfloyd_if.modules.IFStageIII if_III: obj
    :param deepfloyd_if.modules.T5Embedder t5: obj

    :param int seed: int, in case None will use random value
    :param aspect_ratio:
    :param str prompt: text hint/description
    :param str style_prompt: text hint/description for style
    :param str negative_prompt: text hint/description for negative prompt, will use it as unconditional emb
    :param progress:
    :return:
    """
    if seed is None:
        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
    if_I.seed_everything(seed)

    if isinstance(prompt, str):
        prompt = [prompt]

    t5_embs = t5.get_text_embeddings(prompt)

    if_I_kwargs = if_I_kwargs or {}
    if_I_kwargs['seed'] = seed
    if_I_kwargs['t5_embs'] = t5_embs
    if_I_kwargs['aspect_ratio'] = aspect_ratio
    if_I_kwargs['progress'] = progress

    if style_prompt is not None:
        if isinstance(style_prompt, str):
            style_prompt = [style_prompt]
        style_t5_embs = t5.get_text_embeddings(style_prompt)
        if_I_kwargs['style_t5_embs'] = style_t5_embs
        if_I_kwargs['positive_t5_embs'] = style_t5_embs

    if negative_prompt is not None:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        negative_t5_embs = t5.get_text_embeddings(negative_prompt)
        if_I_kwargs['negative_t5_embs'] = negative_t5_embs

    stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
    pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)

    result = {'I': pil_images_I}

    if if_II is not None:
        if_II_kwargs = if_II_kwargs or {}
        if_II_kwargs['low_res'] = stageI_generations
        if_II_kwargs['seed'] = seed
        if_II_kwargs['t5_embs'] = t5_embs
        if_II_kwargs['progress'] = progress
        if_II_kwargs['style_t5_embs'] = if_I_kwargs.get('style_t5_embs')
        if_II_kwargs['positive_t5_embs'] = if_I_kwargs.get('positive_t5_embs')

        stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)

        result['II'] = pil_images_II
    else:
        stageII_generations = None

    if if_II is not None and if_III is not None:
        if_III_kwargs = if_III_kwargs or {}

        stageIII_generations = []
        for idx in range(len(stageII_generations)):
            if if_III.use_diffusers:
                if_III_kwargs['prompt'] = prompt[idx: idx+1]

            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
            if_III_kwargs['seed'] = seed
            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
            if_III_kwargs['progress'] = progress
            style_t5_embs = if_I_kwargs.get('style_t5_embs')
            if style_t5_embs is not None:
                style_t5_embs = style_t5_embs[idx:idx+1]
            positive_t5_embs = if_I_kwargs.get('positive_t5_embs')
            if positive_t5_embs is not None:
                positive_t5_embs = positive_t5_embs[idx:idx+1]
            if_III_kwargs['style_t5_embs'] = style_t5_embs
            if_III_kwargs['positive_t5_embs'] = positive_t5_embs

            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
            stageIII_generations.append(_stageIII_generations)

        stageIII_generations = torch.cat(stageIII_generations, 0)
        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)

        result['III'] = pil_images_III
    else:
        stageIII_generations = None

    if return_tensors:
        return result, (stageI_generations, stageII_generations, stageIII_generations)
    else:
        return result


================================================
FILE: deepfloyd_if/pipelines/inpainting.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime

import PIL
import torch

from .utils import _prepare_pil_image


def inpainting(
    t5,
    if_I,
    if_II=None,
    if_III=None,
    *,
    support_pil_img,
    prompt,
    inpainting_mask,
    negative_prompt=None,
    seed=None,
    if_I_kwargs=None,
    if_II_kwargs=None,
    if_III_kwargs=None,
    progress=True,
    return_tensors=False,
    disable_watermark=False,
):
    from skimage.transform import resize  # noqa
    from skimage import img_as_bool  # noqa
    assert isinstance(support_pil_img, PIL.Image.Image)

    if seed is None:
        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))

    t5_embs = t5.get_text_embeddings(prompt)

    if negative_prompt is not None:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        negative_t5_embs = t5.get_text_embeddings(negative_prompt)
    else:
        negative_t5_embs = None

    low_res = _prepare_pil_image(support_pil_img, 64)
    mid_res = _prepare_pil_image(support_pil_img, 256)
    high_res = _prepare_pil_image(support_pil_img, 1024)

    result = {}

    _, _, image_h, image_w = low_res.shape
    if_I_kwargs = if_I_kwargs or {}
    if_I_kwargs['seed'] = seed
    if_I_kwargs['progress'] = progress
    if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'

    if_I_kwargs['t5_embs'] = t5_embs
    if_I_kwargs['negative_t5_embs'] = negative_t5_embs

    if_I_kwargs['support_noise'] = low_res

    inpainting_mask_I = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))
    inpainting_mask_I = torch.from_numpy(inpainting_mask_I).unsqueeze(0).to(if_I.device)

    if_I_kwargs['inpainting_mask'] = inpainting_mask_I

    stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
    pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)

    result['I'] = pil_images_I

    if if_II is not None:
        _, _, image_h, image_w = mid_res.shape

        if_II_kwargs = if_II_kwargs or {}
        if_II_kwargs['low_res'] = stageI_generations
        if_II_kwargs['seed'] = seed
        if_II_kwargs['t5_embs'] = t5_embs
        if_II_kwargs['negative_t5_embs'] = negative_t5_embs
        if_II_kwargs['progress'] = progress

        if_II_kwargs['support_noise'] = mid_res

        if 'inpainting_mask' not in if_II_kwargs:
            inpainting_mask_II = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))
            inpainting_mask_II = torch.from_numpy(inpainting_mask_II).unsqueeze(0).to(if_II.device)
            if_II_kwargs['inpainting_mask'] = inpainting_mask_II

        stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)

        result['II'] = pil_images_II
    else:
        stageII_generations = None

    if if_II is not None and if_III is not None:
        _, _, image_h, image_w = high_res.shape
        if_III_kwargs = if_III_kwargs or {}

        stageIII_generations = []
        for idx in range(len(stageII_generations)):
            if if_III.use_diffusers:
                if_III_kwargs['prompt'] = prompt[idx: idx+1]

            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
            if_III_kwargs['seed'] = seed
            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
            if negative_t5_embs is not None:
                if_III_kwargs['negative_t5_embs'] = negative_t5_embs[idx:idx+1]
            if_III_kwargs['progress'] = progress
            if_III_kwargs['support_noise'] = high_res

            if 'inpainting_mask' not in if_III_kwargs:
                inpainting_mask_III = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))
                inpainting_mask_III = torch.from_numpy(inpainting_mask_III).unsqueeze(0).to(if_III.device)
                if_III_kwargs['inpainting_mask'] = inpainting_mask_III

            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
            stageIII_generations.append(_stageIII_generations)

        stageIII_generations = torch.cat(stageIII_generations, 0)
        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)

        result['III'] = pil_images_III
    else:
        stageIII_generations = None

    if return_tensors:
        return result, (stageI_generations, stageII_generations, stageIII_generations)
    else:
        return result


================================================
FILE: deepfloyd_if/pipelines/style_transfer.py
================================================
# -*- coding: utf-8 -*-

from datetime import datetime

import PIL
import torch

from .utils import _prepare_pil_image


def style_transfer(
    t5,
    if_I,
    if_II,
    if_III=None,
    *,
    support_pil_img,
    style_prompt,
    prompt=None,
    negative_prompt=None,
    seed=None,
    if_I_kwargs=None,
    if_II_kwargs=None,
    if_III_kwargs=None,
    progress=True,
    return_tensors=False,
    disable_watermark=False,
):
    assert isinstance(support_pil_img, PIL.Image.Image)

    bs = len(style_prompt)

    if seed is None:
        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))

    if prompt is not None:
        t5_embs = t5.get_text_embeddings(prompt)
    else:
        t5_embs = t5.get_text_embeddings(style_prompt)

    style_t5_embs = t5.get_text_embeddings(style_prompt)

    if negative_prompt is not None:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        negative_t5_embs = t5.get_text_embeddings(negative_prompt)
    else:
        negative_t5_embs = None

    low_res = _prepare_pil_image(support_pil_img, 64)
    mid_res = _prepare_pil_image(support_pil_img, 256)
    # high_res = _prepare_pil_image(support_pil_img, 1024)

    result = {}
    if if_I is not None:
        _, _, image_h, image_w = low_res.shape
        if_I_kwargs = if_I_kwargs or {'sample_timestep_respacing': '20,20,20,20,10,0,0,0,0,0'}
        if_I_kwargs['seed'] = seed
        if_I_kwargs['progress'] = progress
        if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'

        if_I_kwargs['t5_embs'] = t5_embs
        if_I_kwargs['style_t5_embs'] = style_t5_embs
        if_I_kwargs['positive_t5_embs'] = style_t5_embs
        if_I_kwargs['negative_t5_embs'] = negative_t5_embs

        if_I_kwargs['support_noise'] = low_res

        stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
        pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)

        result['I'] = pil_images_I
    else:
        stageI_generations = None

    if if_II is not None:
        if stageI_generations is None:
            stageI_generations = low_res.repeat(bs, 1, 1, 1)

        if_II_kwargs = if_II_kwargs or {}
        if_II_kwargs['low_res'] = stageI_generations
        if_II_kwargs['seed'] = seed
        if_II_kwargs['t5_embs'] = t5_embs
        if_II_kwargs['style_t5_embs'] = style_t5_embs
        if_II_kwargs['positive_t5_embs'] = style_t5_embs
        if_II_kwargs['negative_t5_embs'] = negative_t5_embs
        if_II_kwargs['progress'] = progress

        if_II_kwargs['support_noise'] = mid_res

        stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)

        result['II'] = pil_images_II
    else:
        stageII_generations = None

    if if_II is not None and if_III is not None:
        if_III_kwargs = if_III_kwargs or {}

        stageIII_generations = []
        for idx in range(len(stageII_generations)):
            if if_III.use_diffusers:
                if_III_kwargs['prompt'] = prompt[idx: idx+1] if prompt is not None else style_prompt[idx: idx+1]

            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
            if_III_kwargs['seed'] = seed
            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
            if_III_kwargs['progress'] = progress
            style_t5_embs = if_II_kwargs.get('style_t5_embs')
            if style_t5_embs is not None:
                style_t5_embs = style_t5_embs[idx:idx+1]
            positive_t5_embs = if_II_kwargs.get('positive_t5_embs')
            if positive_t5_embs is not None:
                positive_t5_embs = positive_t5_embs[idx:idx+1]
            if_III_kwargs['style_t5_embs'] = style_t5_embs
            if_III_kwargs['positive_t5_embs'] = positive_t5_embs

            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
            stageIII_generations.append(_stageIII_generations)

        stageIII_generations = torch.cat(stageIII_generations, 0)
        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)

        result['III'] = pil_images_III
    else:
        stageIII_generations = None

    if return_tensors:
        return result, (stageI_generations, stageII_generations, stageIII_generations)
    else:
        return result


================================================
FILE: deepfloyd_if/pipelines/super_resolution.py
================================================
# -*- coding: utf-8 -*-

from datetime import datetime

import PIL
from .utils import _prepare_pil_image


def super_resolution(
    t5,
    if_III=None,
    *,
    support_pil_img,
    prompt=None,
    negative_prompt=None,
    seed=None,
    if_III_kwargs=None,
    progress=True,
    img_size=256,
    img_scale=4.0,
    return_tensors=False,
    disable_watermark=False,
):
    assert isinstance(support_pil_img, PIL.Image.Image)
    assert img_size % 8 == 0

    if seed is None:
        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))

    if prompt is not None:
        t5_embs = t5.get_text_embeddings(prompt)
    else:
        t5_embs = t5.get_text_embeddings('')

    if negative_prompt is not None:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        negative_t5_embs = t5.get_text_embeddings(negative_prompt)
    else:
        negative_t5_embs = None

    low_res = _prepare_pil_image(support_pil_img, img_size)

    result = {}

    bs = 1
    if_III_kwargs = if_III_kwargs or {}

    if if_III.use_diffusers:
        if_III_kwargs['prompt'] = prompt

    if_III_kwargs['low_res'] = low_res.repeat(bs, 1, 1, 1)
    if_III_kwargs['seed'] = seed
    if_III_kwargs['t5_embs'] = t5_embs
    if_III_kwargs['negative_t5_embs'] = negative_t5_embs
    if_III_kwargs['progress'] = progress
    if_III_kwargs['img_scale'] = img_scale

    stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
    pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)
    result['III'] = pil_images_III

    if return_tensors:
        return result, (stageIII_generations,)
    else:
        return result


================================================
FILE: deepfloyd_if/pipelines/utils.py
================================================
# -*- coding: utf-8 -*-

import torch
import numpy as np
from PIL import Image


def _prepare_pil_image(raw_pil_img, img_size):
    raw_pil_img = raw_pil_img.convert('RGB')
    w, h = raw_pil_img.size
    coef = w / h
    image_h, image_w = img_size, img_size
    if coef >= 1:
        image_w = int(round(img_size / 8 * coef) * 8)
    else:
        image_h = int(round(img_size / 8 / coef) * 8)

    pil_img = raw_pil_img.resize(
        (image_w, image_h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None
    )
    img = np.array(pil_img)
    img = img.astype(np.float32) / 127.5 - 1
    img = np.transpose(img, [2, 0, 1])
    img = torch.from_numpy(img).unsqueeze(0)
    return img


================================================
FILE: deepfloyd_if/utils.py
================================================
# -*- coding: utf-8 -*-
from os.path import abspath, dirname, join

import torch
import numpy as np
from PIL import Image, ImageFilter

RESOURCES_ROOT = join(abspath(dirname(__file__)), 'resources')


def drop_shadow(image, offset=(5, 5), background=0xffffff, shadow=0x444444, border=8, iterations=3):
    """
    Drop shadows with PIL.
    Author: Kevin Schluff
    License: Python license
    https://code.activestate.com/recipes/474116/

    Add a gaussian blur drop shadow to an image.

    image       - The image to overlay on top of the shadow.
    offset      - Offset of the shadow from the image as an (x,y) tuple.  Can be
                  positive or negative.
    background  - Background colour behind the image.
    shadow      - Shadow colour (darkness).
    border      - Width of the border around the image.  This must be wide
                  enough to account for the blurring of the shadow.
    iterations  - Number of times to apply the filter.  More iterations
                  produce a more blurred shadow, but increase processing time.
    """

    # Create the backdrop image -- a box in the background colour with a
    # shadow on it.
    total_width = image.size[0] + abs(offset[0]) + 2 * border
    total_height = image.size[1] + abs(offset[1]) + 2 * border
    back = Image.new(image.mode, (total_width, total_height), background)

    # Place the shadow, taking into account the offset from the image
    shadow_left = border + max(offset[0], 0)
    shadow_top = border + max(offset[1], 0)
    back.paste(shadow, [shadow_left, shadow_top, shadow_left + image.size[0], shadow_top + image.size[1]])

    # Apply the filter to blur the edges of the shadow.  Since a small kernel
    # is used, the filter must be applied repeatedly to get a decent blur.
    n = 0
    while n < iterations:
        back = back.filter(ImageFilter.BLUR)
        n += 1

    # Paste the input image onto the shadow backdrop
    image_left = border - min(offset[0], 0)
    image_top = border - min(offset[1], 0)
    back.paste(image, (image_left, image_top))
    return back


def pil_list_to_torch_tensors(pil_images):
    result = []
    for pil_image in pil_images:
        image = np.array(pil_image, dtype=np.uint8)
        image = torch.from_numpy(image)
        image = image.permute(2, 0, 1).unsqueeze(0)
        result.append(image)
    return torch.cat(result, dim=0)


================================================
FILE: requirements-dev.txt
================================================
-r requirements-test.txt
pre-commit


================================================
FILE: requirements-test.txt
================================================
-r requirements.txt
pytest
pytest-cov


================================================
FILE: requirements.txt
================================================
tqdm
numpy
torch<2.0.0
torchvision
omegaconf
matplotlib
Pillow>=9.2.0
huggingface_hub>=0.13.2
transformers~=4.25.1
accelerate~=0.15.0
diffusers~=0.16.0
tokenizers~=0.13.2
sentencepiece~=0.1.97
ftfy~=6.1.1
beautifulsoup4~=4.11.1


================================================
FILE: setup.cfg
================================================
[pep8]
max-line-length = 120
exclude = .tox,*migrations*,.json

[flake8]
max-line-length = 120
exclude = .tox,*migrations*,.json

[autopep8-wrapper]
exclude = .tox,*migrations*,.json

[check-docstring-first]
exclude = .tox,*migrations*,.json


================================================
FILE: setup.py
================================================
# -*- coding: utf-8 -*-
import os
import re
from setuptools import setup


def read(filename):
    with open(os.path.join(os.path.dirname(__file__), filename)) as f:
        file_content = f.read()
    return file_content


def get_requirements():
    requirements = []
    for requirement in read('requirements.txt').splitlines():
        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):
            parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
            if parsed_requires:
                package, version = parsed_requires[0]
                requirements.append(f'{package}=={version}')
            else:
                print('WARNING! For correct matching dependency links need to specify package name and version'
                      'such as <dependency url>#egg=<package_name>-<version>')
        else:
            requirements.append(requirement)
    return requirements


def get_links():
    return [
        requirement for requirement in read('requirements.txt').splitlines()
        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')
    ]


def get_version():
    """ Get version from the package without actually importing it. """
    init = read('deepfloyd_if/__init__.py')
    for line in init.split('\n'):
        if line.startswith('__version__'):
            return eval(line.split('=')[1])


setup(
    name='deepfloyd_if',
    version=get_version(),
    author='DeepFloyd, StabilityAI',
    author_email='shonenkov@gmail.com',
    description='DeepFloyd-IF (Imagen Free)',
    packages=['deepfloyd_if', 'deepfloyd_if/model', 'deepfloyd_if/modules', 'deepfloyd_if/pipelines',
              'deepfloyd_if/resources'],
    package_data={'deepfloyd_if/resources': ['*.png', '*.npy', '*.npz']},
    install_requires=get_requirements(),
    dependency_links=get_links(),
    long_description=read('README.md'),
    long_description_content_type='text/markdown',
)
Download .txt
gitextract_r39ejdyw/

├── .gitattributes
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── LICENSE
├── LICENSE-MODEL
├── README.md
├── deepfloyd_if/
│   ├── __init__.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── gaussian_diffusion.py
│   │   ├── losses.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── respace.py
│   │   └── unet.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── stage_I.py
│   │   ├── stage_II.py
│   │   ├── stage_III.py
│   │   ├── stage_III_sd_x4.py
│   │   ├── t5.py
│   │   └── utils.py
│   ├── pipelines/
│   │   ├── __init__.py
│   │   ├── dream.py
│   │   ├── inpainting.py
│   │   ├── style_transfer.py
│   │   ├── super_resolution.py
│   │   └── utils.py
│   ├── resources/
│   │   ├── p_head_v1.npz
│   │   ├── w_head_v1.npz
│   │   └── zero_t5-v1_1-xxl_vector.npy
│   └── utils.py
├── requirements-dev.txt
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
└── setup.py
Download .txt
SYMBOL INDEX (144 symbols across 20 files)

FILE: deepfloyd_if/model/gaussian_diffusion.py
  function get_named_beta_schedule (line 17) | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
  function betas_for_alpha_bar (line 43) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
  class ModelMeanType (line 62) | class ModelMeanType(enum.Enum):
  class ModelVarType (line 72) | class ModelVarType(enum.Enum):
  class LossType (line 85) | class LossType(enum.Enum):
    method is_vb (line 93) | def is_vb(self):
  class GaussianDiffusion (line 97) | class GaussianDiffusion:
    method __init__ (line 112) | def __init__(
    method dynamic_thresholding (line 165) | def dynamic_thresholding(self, x, p=0.995, c=1.7):
    method q_mean_variance (line 184) | def q_mean_variance(self, x_start, t):
    method q_sample (line 200) | def q_sample(self, x_start, t, noise=None):
    method q_posterior_mean_variance (line 218) | def q_posterior_mean_variance(self, x_start, x_t, t):
    method p_mean_variance (line 240) | def p_mean_variance(
    method _predict_xstart_from_eps (line 337) | def _predict_xstart_from_eps(self, x_t, t, eps):
    method _predict_xstart_from_xprev (line 344) | def _predict_xstart_from_xprev(self, x_t, t, xprev):
    method _predict_eps_from_xstart (line 354) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    method _scale_timesteps (line 360) | def _scale_timesteps(self, t):
    method p_sample (line 365) | def p_sample(
    method p_sample_loop (line 404) | def p_sample_loop(
    method p_sample_loop_progressive (line 454) | def p_sample_loop_progressive(
    method ddim_sample (line 507) | def ddim_sample(
    method ddim_reverse_sample (line 555) | def ddim_reverse_sample(
    method ddim_sample_loop (line 597) | def ddim_sample_loop(
    method ddim_sample_loop_progressive (line 635) | def ddim_sample_loop_progressive(
    method _vb_terms_bpd (line 686) | def _vb_terms_bpd(
    method training_losses (line 719) | def training_losses(self, model, x_start, t, model_kwargs=None, noise=...
    method _prior_bpd (line 793) | def _prior_bpd(self, x_start):
    method calc_bpd_loop (line 809) | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwar...
  function _extract_into_tensor (line 865) | def _extract_into_tensor(arr, timesteps, broadcast_shape):

FILE: deepfloyd_if/model/losses.py
  function normal_kl (line 12) | def normal_kl(mean1, logvar1, mean2, logvar2):
  function approx_standard_normal_cdf (line 41) | def approx_standard_normal_cdf(x):
  function discretized_gaussian_log_likelihood (line 49) | def discretized_gaussian_log_likelihood(x, *, means, log_scales):

FILE: deepfloyd_if/model/nn.py
  function mean_flat (line 10) | def mean_flat(tensor):
  function gelu (line 17) | def gelu(x):
  function gelu_jit (line 22) | def gelu_jit(x):
  class GELUJit (line 27) | class GELUJit(torch.nn.Module):
    method forward (line 28) | def forward(self, input: Tensor) -> Tensor:
  function get_activation (line 32) | def get_activation(activation):
  class GroupNorm32 (line 45) | class GroupNorm32(nn.GroupNorm):
    method __init__ (line 46) | def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
    method forward (line 49) | def forward(self, x):
  class AttentionPooling (line 54) | class AttentionPooling(nn.Module):
    method __init__ (line 56) | def __init__(self, num_heads, embed_dim, dtype=None):
    method forward (line 66) | def forward(self, x):
  function conv_nd (line 105) | def conv_nd(dims, *args, **kwargs):
  function linear (line 118) | def linear(*args, **kwargs):
  function avg_pool_nd (line 125) | def avg_pool_nd(dims, *args, **kwargs):
  function zero_module (line 138) | def zero_module(module):
  function scale_module (line 147) | def scale_module(module, scale):
  function normalization (line 156) | def normalization(channels, dtype=None):
  function timestep_embedding (line 165) | def timestep_embedding(timesteps, dim, max_period=10000, dtype=None):
  function attention (line 187) | def attention(q, k, v, d_k):

FILE: deepfloyd_if/model/resample.py
  class ScheduleSampler (line 8) | class ScheduleSampler(ABC):
    method weights (line 19) | def weights(self):
    method sample (line 25) | def sample(self, batch_size, device):
  class UniformSampler (line 43) | class UniformSampler(ScheduleSampler):
    method __init__ (line 44) | def __init__(self, num_timesteps):
    method weights (line 47) | def weights(self):
  class StaticSampler (line 51) | class StaticSampler(ABC):
    method sample (line 53) | def sample(self, batch_size, device, static_step=100):

FILE: deepfloyd_if/model/respace.py
  function create_gaussian_diffusion (line 8) | def create_gaussian_diffusion(
  function space_timesteps (line 49) | def space_timesteps(num_timesteps, section_counts):
  class SpacedDiffusion (line 108) | class SpacedDiffusion(gd.GaussianDiffusion):
    method __init__ (line 116) | def __init__(self, use_timesteps, **kwargs):
    method p_mean_variance (line 132) | def p_mean_variance(
    method training_losses (line 137) | def training_losses(
    method _wrap_model (line 142) | def _wrap_model(self, model):
    method _scale_timesteps (line 149) | def _scale_timesteps(self, t):
  class _WrappedModel (line 154) | class _WrappedModel:
    method __init__ (line 155) | def __init__(self, model, timestep_map, rescale_timesteps, original_nu...
    method __call__ (line 161) | def __call__(self, x, ts, **kwargs):

FILE: deepfloyd_if/model/unet.py
  class TimestepBlock (line 20) | class TimestepBlock(nn.Module):
    method forward (line 26) | def forward(self, x, emb):
  class TimestepEmbedSequential (line 32) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    method forward (line 38) | def forward(self, x, emb, encoder_out=None):
  class Upsample (line 49) | class Upsample(nn.Module):
    method __init__ (line 58) | def __init__(self, channels, use_conv, dims=2, out_channels=None, dtyp...
    method forward (line 68) | def forward(self, x):
  class Downsample (line 83) | class Downsample(nn.Module):
    method __init__ (line 92) | def __init__(self, channels, use_conv, dims=2, out_channels=None, dtyp...
    method forward (line 106) | def forward(self, x):
  class ResBlock (line 111) | class ResBlock(TimestepBlock):
    method __init__ (line 126) | def __init__(
    method forward (line 192) | def forward(self, x, emb):
  class AttentionBlock (line 225) | class AttentionBlock(nn.Module):
    method __init__ (line 232) | def __init__(
    method forward (line 265) | def forward(self, x, encoder_out=None):
  class QKVAttention (line 280) | class QKVAttention(nn.Module):
    method __init__ (line 285) | def __init__(self, n_heads, disable_self_attention=False):
    method forward (line 290) | def forward(self, qkv, encoder_kv=None):
  class UNetModel (line 326) | class UNetModel(nn.Module):
    method __init__ (line 353) | def __init__(
    method forward (line 628) | def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_...
  class SuperResUNetModel (line 664) | class SuperResUNetModel(UNetModel):
    method __init__ (line 670) | def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *ar...
    method forward (line 681) | def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):

FILE: deepfloyd_if/modules/base.py
  class IFBaseModule (line 23) | class IFBaseModule:
    method __init__ (line 57) | def __init__(self, dir_or_name, device, pil_img_size=256, cache_dir=No...
    method use_diffusers (line 67) | def use_diffusers(self):
    method embeddings_to_image (line 70) | def embeddings_to_image(
    method load_conf (line 231) | def load_conf(self, dir_or_name, filename='config.yml'):
    method load_checkpoint (line 236) | def load_checkpoint(self, model, dir_or_name, filename='pytorch_model....
    method _get_path_or_download_file_from_hf (line 247) | def _get_path_or_download_file_from_hf(self, dir_or_name, filename):
    method get_diffusion (line 256) | def get_diffusion(self, timestep_respacing):
    method seed_everything (line 272) | def seed_everything(seed=None):
    method device_name (line 284) | def device_name(self):
    method to_images (line 291) | def to_images(self, generations, disable_watermark=False):
    method show (line 312) | def show(self, pil_images, nrow=None, size=10):
    method _clear_cache (line 330) | def _clear_cache(self):
    method _get_image_sizes (line 333) | def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale):
    method __validate_generations (line 352) | def __validate_generations(self, generations):

FILE: deepfloyd_if/modules/stage_I.py
  class IFStageI (line 8) | class IFStageI(IFBaseModule):
    method __init__ (line 12) | def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
    method embeddings_to_image (line 27) | def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5...

FILE: deepfloyd_if/modules/stage_II.py
  class IFStageII (line 8) | class IFStageII(IFBaseModule):
    method __init__ (line 12) | def __init__(self, *args, model_kwargs=None, pil_img_size=256, **kwargs):
    method embeddings_to_image (line 21) | def embeddings_to_image(

FILE: deepfloyd_if/modules/stage_III.py
  class IFStageIII (line 8) | class IFStageIII(IFBaseModule):
    method __init__ (line 12) | def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwar...
    method embeddings_to_image (line 21) | def embeddings_to_image(

FILE: deepfloyd_if/modules/stage_III_sd_x4.py
  class StableStageIII (line 11) | class StableStageIII(IFBaseModule):
    method __init__ (line 15) | def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwar...
    method use_diffusers (line 41) | def use_diffusers(self):
    method embeddings_to_image (line 48) | def embeddings_to_image(

FILE: deepfloyd_if/modules/t5.py
  class T5Embedder (line 14) | class T5Embedder:
    method __init__ (line 19) | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=Non...
    method get_text_embeddings (line 87) | def get_text_embeddings(self, texts):
    method text_preprocessing (line 110) | def text_preprocessing(self, text):
    method basic_clean (line 120) | def basic_clean(text):
    method clean_caption (line 125) | def clean_caption(self, caption):

FILE: deepfloyd_if/modules/utils.py
  function predict_proba (line 6) | def predict_proba(X, weights, biases):
  function load_model_weights (line 12) | def load_model_weights(path):
  function clip_process_generations (line 17) | def clip_process_generations(generations):

FILE: deepfloyd_if/pipelines/dream.py
  function dream (line 7) | def dream(

FILE: deepfloyd_if/pipelines/inpainting.py
  function inpainting (line 10) | def inpainting(

FILE: deepfloyd_if/pipelines/style_transfer.py
  function style_transfer (line 11) | def style_transfer(

FILE: deepfloyd_if/pipelines/super_resolution.py
  function super_resolution (line 9) | def super_resolution(

FILE: deepfloyd_if/pipelines/utils.py
  function _prepare_pil_image (line 8) | def _prepare_pil_image(raw_pil_img, img_size):

FILE: deepfloyd_if/utils.py
  function drop_shadow (line 11) | def drop_shadow(image, offset=(5, 5), background=0xffffff, shadow=0x4444...
  function pil_list_to_torch_tensors (line 56) | def pil_list_to_torch_tensors(pil_images):

FILE: setup.py
  function read (line 7) | def read(filename):
  function get_requirements (line 13) | def get_requirements():
  function get_links (line 29) | def get_links():
  function get_version (line 36) | def get_version():
Condensed preview — 38 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (176K chars).
[
  {
    "path": ".gitattributes",
    "chars": 71,
    "preview": "notebooks/pipes-DeepFloyd-IF.ipynb filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".gitignore",
    "chars": 1815,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 582,
    "preview": "repos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.2.0\n    hooks:\n    -   id: check-docstring-f"
  },
  {
    "path": "CHANGELOG.md",
    "chars": 527,
    "preview": "v1.0.2rc\n-------\n\n- uses separated tokenizer_path to init tokenizer in T5Embedder\n\nv1.0.1\n------\n\n- renamed main model `"
  },
  {
    "path": "LICENSE",
    "chars": 1504,
    "preview": "Copyright (c) 2023 DeepFloyd, StabilityAI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
  },
  {
    "path": "LICENSE-MODEL",
    "chars": 11529,
    "preview": "DEEPFLOYD IF LICENSE AGREEMENT\n\nThis License Agreement (as may be amended in accordance with this License Agreement, “Li"
  },
  {
    "path": "README.md",
    "chars": 15597,
    "preview": "[![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE)\n[![License](https://img.shields.i"
  },
  {
    "path": "deepfloyd_if/__init__.py",
    "chars": 51,
    "preview": "# -*- coding: utf-8 -*-\n\n\n__version__ = '1.0.2rc0'\n"
  },
  {
    "path": "deepfloyd_if/model/__init__.py",
    "chars": 118,
    "preview": "# -*- coding: utf-8 -*-\nfrom .unet import UNetModel, SuperResUNetModel\n\n\n__all__ = ['UNetModel', 'SuperResUNetModel']\n"
  },
  {
    "path": "deepfloyd_if/model/gaussian_diffusion.py",
    "chars": 35215,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\nThis code started out as a PyTorch port of Ho et al's diffusion model:\nhttps://github.com/ho"
  },
  {
    "path": "deepfloyd_if/model/losses.py",
    "chars": 2587,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\nHelpers for various likelihood-based losses. These are ported from the original\nHo et al. di"
  },
  {
    "path": "deepfloyd_if/model/nn.py",
    "chars": 5965,
    "preview": "# -*- coding: utf-8 -*-\nimport math\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import"
  },
  {
    "path": "deepfloyd_if/model/resample.py",
    "chars": 2001,
    "preview": "# -*- coding: utf-8 -*-\nfrom abc import ABC, abstractmethod\n\nimport torch\nimport numpy as np\n\n\nclass ScheduleSampler(ABC"
  },
  {
    "path": "deepfloyd_if/model/respace.py",
    "chars": 6379,
    "preview": "# -*- coding: utf-8 -*-\nimport torch\nimport numpy as np\n\nfrom . import gaussian_diffusion as gd\n\n\ndef create_gaussian_di"
  },
  {
    "path": "deepfloyd_if/model/unet.py",
    "chars": 27952,
    "preview": "# -*- coding: utf-8 -*-\nimport os\nimport math\nfrom abc import abstractmethod\n\nimport torch\nimport numpy as np\nimport tor"
  },
  {
    "path": "deepfloyd_if/modules/__init__.py",
    "chars": 321,
    "preview": "# -*- coding: utf-8 -*-\nfrom .stage_I import IFStageI\nfrom .stage_II import IFStageII\nfrom .stage_III import IFStageIII\n"
  },
  {
    "path": "deepfloyd_if/modules/base.py",
    "chars": 15411,
    "preview": "# -*- coding: utf-8 -*-\nimport os\nimport random\nimport platform\nfrom datetime import datetime\n\nimport torch\nimport torch"
  },
  {
    "path": "deepfloyd_if/modules/stage_I.py",
    "chars": 1977,
    "preview": "# -*- coding: utf-8 -*-\nimport accelerate\n\nfrom .base import IFBaseModule\nfrom ..model import UNetModel\n\n\nclass IFStageI"
  },
  {
    "path": "deepfloyd_if/modules/stage_II.py",
    "chars": 1939,
    "preview": "# -*- coding: utf-8 -*-\nimport accelerate\n\nfrom .base import IFBaseModule\nfrom ..model import SuperResUNetModel\n\n\nclass "
  },
  {
    "path": "deepfloyd_if/modules/stage_III.py",
    "chars": 1962,
    "preview": "# -*- coding: utf-8 -*-\nimport accelerate\n\nfrom .base import IFBaseModule\nfrom ..model import SuperResUNetModel\n\n\nclass "
  },
  {
    "path": "deepfloyd_if/modules/stage_III_sd_x4.py",
    "chars": 3126,
    "preview": "# -*- coding: utf-8 -*-\nimport diffusers\nfrom diffusers import DiffusionPipeline, DDPMScheduler\nimport torch\nimport os\n\n"
  },
  {
    "path": "deepfloyd_if/modules/t5.py",
    "chars": 9860,
    "preview": "# -*- coding: utf-8 -*-\nimport os\nimport re\nimport html\nimport urllib.parse as ul\n\nimport ftfy\nimport torch\nfrom bs4 imp"
  },
  {
    "path": "deepfloyd_if/modules/utils.py",
    "chars": 742,
    "preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport torchvision.transforms as T\n\n\ndef predict_proba(X, weights, biases):\n "
  },
  {
    "path": "deepfloyd_if/pipelines/__init__.py",
    "chars": 247,
    "preview": "# -*- coding: utf-8 -*-\nfrom .dream import dream\nfrom .style_transfer import style_transfer\nfrom .super_resolution impor"
  },
  {
    "path": "deepfloyd_if/pipelines/dream.py",
    "chars": 5565,
    "preview": "# -*- coding: utf-8 -*-\nfrom datetime import datetime\n\nimport torch\n\n\ndef dream(\n    t5,\n    if_I,\n    if_II=None,\n    i"
  },
  {
    "path": "deepfloyd_if/pipelines/inpainting.py",
    "chars": 4522,
    "preview": "# -*- coding: utf-8 -*-\nfrom datetime import datetime\n\nimport PIL\nimport torch\n\nfrom .utils import _prepare_pil_image\n\n\n"
  },
  {
    "path": "deepfloyd_if/pipelines/style_transfer.py",
    "chars": 4443,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom datetime import datetime\n\nimport PIL\nimport torch\n\nfrom .utils import _prepare_pil_image\n\n"
  },
  {
    "path": "deepfloyd_if/pipelines/super_resolution.py",
    "chars": 1722,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom datetime import datetime\n\nimport PIL\nfrom .utils import _prepare_pil_image\n\n\ndef super_res"
  },
  {
    "path": "deepfloyd_if/pipelines/utils.py",
    "chars": 709,
    "preview": "# -*- coding: utf-8 -*-\n\nimport torch\nimport numpy as np\nfrom PIL import Image\n\n\ndef _prepare_pil_image(raw_pil_img, img"
  },
  {
    "path": "deepfloyd_if/utils.py",
    "chars": 2390,
    "preview": "# -*- coding: utf-8 -*-\nfrom os.path import abspath, dirname, join\n\nimport torch\nimport numpy as np\nfrom PIL import Imag"
  },
  {
    "path": "requirements-dev.txt",
    "chars": 36,
    "preview": "-r requirements-test.txt\npre-commit\n"
  },
  {
    "path": "requirements-test.txt",
    "chars": 38,
    "preview": "-r requirements.txt\npytest\npytest-cov\n"
  },
  {
    "path": "requirements.txt",
    "chars": 228,
    "preview": "tqdm\nnumpy\ntorch<2.0.0\ntorchvision\nomegaconf\nmatplotlib\nPillow>=9.2.0\nhuggingface_hub>=0.13.2\ntransformers~=4.25.1\naccel"
  },
  {
    "path": "setup.cfg",
    "chars": 242,
    "preview": "[pep8]\nmax-line-length = 120\nexclude = .tox,*migrations*,.json\n\n[flake8]\nmax-line-length = 120\nexclude = .tox,*migration"
  },
  {
    "path": "setup.py",
    "chars": 2018,
    "preview": "# -*- coding: utf-8 -*-\nimport os\nimport re\nfrom setuptools import setup\n\n\ndef read(filename):\n    with open(os.path.joi"
  }
]

// ... and 3 more files (download for full content)

About this extraction

This page contains the full source code of the deep-floyd/IF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 38 files (165.4 KB), approximately 42.9k tokens, and a symbol index with 144 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!