Repository: deep-floyd/IF
Branch: develop
Commit: ffc816389168
Files: 38
Total size: 165.4 KB
Directory structure:
gitextract_r39ejdyw/
├── .gitattributes
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── LICENSE
├── LICENSE-MODEL
├── README.md
├── deepfloyd_if/
│ ├── __init__.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── gaussian_diffusion.py
│ │ ├── losses.py
│ │ ├── nn.py
│ │ ├── resample.py
│ │ ├── respace.py
│ │ └── unet.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── stage_I.py
│ │ ├── stage_II.py
│ │ ├── stage_III.py
│ │ ├── stage_III_sd_x4.py
│ │ ├── t5.py
│ │ └── utils.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── dream.py
│ │ ├── inpainting.py
│ │ ├── style_transfer.py
│ │ ├── super_resolution.py
│ │ └── utils.py
│ ├── resources/
│ │ ├── p_head_v1.npz
│ │ ├── w_head_v1.npz
│ │ └── zero_t5-v1_1-xxl_vector.npy
│ └── utils.py
├── requirements-dev.txt
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
notebooks/pipes-DeepFloyd-IF.ipynb filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, deepfloyd_if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
hooks:
- id: check-docstring-first
- id: check-merge-conflict
stages:
- push
- id: double-quote-string-fixer
- id: end-of-file-fixer
- id: fix-encoding-pragma
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/pycqa/flake8
rev: "4.0.1"
hooks:
- id: flake8
args: ['--config=setup.cfg']
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v1.6.0
hooks:
- id: autopep8
================================================
FILE: CHANGELOG.md
================================================
v1.0.2rc
-------
- uses separated tokenizer_path to init tokenizer in T5Embedder
v1.0.1
------
- renamed main model `IF-I-IF` --> `IF-I-XL`
- moved dir `notebooks` to HF storage https://huggingface.co/DeepFloyd/IF-notebooks; lets keep new notebooks there;
- added additional kaggle notebook (more free GPU resources) how to generate pictures 1k: [](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)
v1.0.0
------
- initial version
================================================
FILE: LICENSE
================================================
Copyright (c) 2023 DeepFloyd, StabilityAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
1. The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
2. All persons obtaining a copy or substantial portion of the Software,
a modified version of the Software (or substantial portion thereof), or
a derivative work based upon this Software (or substantial portion thereof)
must not delete, remove, disable, diminish, or circumvent any inference filters or
inference filter mechanisms in the Software, or any portion of the Software that
implements any such filters or filter mechanisms.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: LICENSE-MODEL
================================================
DEEPFLOYD IF LICENSE AGREEMENT
This License Agreement (as may be amended in accordance with this License Agreement, “License”),
between you, or your employer or other entity (if you are entering into this agreement on behalf
of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd.. (“Stability AI” or “we”)
applies to your use of any computer program, algorithm, source code, object code, or software that is made
available by Stability AI under this License (“Software”) and any specifications, manuals, documentation,
and other written information provided by Stability AI related to the Software (“Documentation”).
By clicking “I Accept” below or by using the Software, you agree to the terms of this License.
If you do not agree to this License, then you do not have any rights to use the Software or
Documentation (collectively, the “Software Products”), and you must immediately cease using
the Software Products. If you are agreeing to be bound by the terms of this License on behalf
of your employer or other entity, you represent and warrant to Stability AI that you have full legal
authority to bind your employer or such entity to this License. If you do not have the requisite authority,
you may not accept the License or access the Software Products on behalf of your employer or other entity.
1. LICENSE GRANT
a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants
you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited
license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of
the Software solely for your non-commercial research purposes. The foregoing license is personal to you,
and you may not assign or sublicense this License or any other rights or obligations under this License
without Stability AI’s prior written consent; any such assignment or sublicense will be void and will
automatically and immediately terminate this License.
b. You may make a reasonable number of copies of the Documentation solely for use in connection with
the license to the Software granted above.
c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete
grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver,
estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights
not expressly granted by this License.
2. RESTRICTIONS
You will not, and will not permit, assist or cause any third party to:
a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products
(or any derivative works thereof, works incorporating the Software Products, or any data produced
by the Software), in whole or in part, for (i) any commercial or production purposes,
(ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance,
including any research or development relating to surveillance, (iv) biometric processing,
(v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights,
or (vi) in any manner that violates any applicable law and violating any privacy or security laws,
rules, regulations, directives, or governmental requirements (including the General Data Privacy
Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws
governing the processing of biometric information), as well as all amendments and successor laws
to any of the foregoing;
b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
c. utilize any equipment, device, software, or other means to circumvent or remove any security or
protection used by Stability AI in connection with the Software, or to circumvent or remove any
usage restrictions, or to enable functionality disabled by Stability AI; or
d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent
with the terms of this License.
e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws
(“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise
transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b)
to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited
by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications;
3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned
jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any
purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
3. ATTRIBUTION
Together with any copies of the Software Products (as well as derivative works thereof or works
incorporating the Software Products) that you distribute, you must provide (i) a copy of this License,
and (ii) the following attribution notice: “DeepFloyd is licensed under the DeepFloyd License,
Copyright (c) Stability AI Ltd. All Rights Reserved.”
4. DISCLAIMERS
THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED. STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED,
WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS,
INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,
TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS
THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS,
OR PRODUCE ANY PARTICULAR RESULTS.
5. LIMITATION OF LIABILITY
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER
ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY,
OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL,
PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY
OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT
(COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR
SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD
TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S
PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”).
IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK.
YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND
POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY
OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL
THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
6. INDEMNIFICATION
You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates,
and each of our respective shareholders, directors, officers, employees, agents, successors,
and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities,
damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any
Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or
investigation (collectively, “Claims”) arising out of or related to: (a) your access to or
use of the Software Products (as well as any results or data generated from such access or use),
including any High-Risk Use (defined below); (b) your violation of this License; or (c)
your violation, misappropriation or infringement of any rights of another (including intellectual
property or other proprietary rights and privacy rights). You will promptly notify the Stability AI
Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims.
You will also grant the Stability AI Parties sole control of the defense or settlement,
at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of,
any other indemnities or remedies set forth in a written agreement between you and
Stability AI or the other Stability AI Parties.
7. TERMINATION; SURVIVAL
a. This License will automatically terminate upon any breach by you of the terms of this License.
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution),
4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival),
8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
8. THIRD PARTY MATERIALS
The Software Products may contain third-party software or other components (including free and
open source software) (all of the foregoing, “Third Party Materials”), which are subject to
the license terms of the respective third-party licensors. Your dealings or correspondence
with third parties and your use of or interaction with any Third Party Materials are solely
between you and the third party. Stability AI does not control or endorse, and makes
no representations or warranties regarding, any Third Party Materials, and your access
to and use of such Third Party Materials are at your own risk.
9. TRADEMARKS
Licensee has not been granted any trademark license as part of this License and may not use any name
or mark associated with Stability AI without the prior written permission of Stability AI, except to
the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
10. APPLICABLE LAW; DISPUTE RESOLUTION
This License will be governed and construed under the laws of the State of California without regard
to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License
will be brought in the federal or state courts, as applicable, in San Mateo County, California,
and each party irrevocably submits to the jurisdiction and venue of such courts.
11. MISCELLANEOUS
If any provision or part of a provision of this License is unlawful, void or unenforceable,
that provision or part of the provision is deemed severed from this License, and will not affect
the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise
or enforce any right or provision of this License will not operate as a waiver of such right or provision.
This License does not confer any third-party beneficiary rights upon any other person or entity.
This License, together with the Documentation, contains the entire understanding between you and
Stability AI regarding the subject matter of this License, and supersedes all other written or
oral agreements and understandings between you and Stability AI regarding such subject matter.
No change or addition to any provision of this License will be binding unless it is in writing and
signed by an authorized representative of both you and Stability AI.
================================================
FILE: README.md
================================================
[](LICENSE)
[](LICENSE-MODEL)
[](https://pepy.tech/project/deepfloyd_if)
[](https://discord.gg/umz62Mgr)
[](https://twitter.com/deepfloydai)
[](http://linktr.ee/deepfloyd)
# IF by [DeepFloyd Lab](https://deepfloyd.ai) at [StabilityAI](https://stability.ai/)
We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates 64x64 px image based on text prompt and two super-resolution models, each designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf)
## Minimum requirements to use all IF models:
- 16GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module)
- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler)
- `xformers` and set env variable `FORCE_MEM_EFFICIENT_ATTN=1`
## Quick Start
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
[](https://huggingface.co/spaces/DeepFloyd/IF)
```shell
pip install deepfloyd_if==1.0.2rc0
pip install xformers==0.0.16
pip install git+https://github.com/openai/CLIP.git --no-deps
```
## Local notebooks
[](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb)
[](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)
The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb).
## Integration with 🤗 Diffusers
IF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/).
Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing to inspect intermediate results easily.
### Example
Before you can use IF, you need to accept its usage conditions. To do so:
1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be loggin in
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
3. Make sure to login locally. Install `huggingface_hub`
```sh
pip install huggingface_hub --upgrade
```
run the login function in a Python shell
```py
from huggingface_hub import login
login()
```
and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
Next we install `diffusers` and dependencies:
```sh
pip install diffusers accelerate transformers safetensors
```
And we can now run the model locally.
By default `diffusers` makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM.
If you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()`
functions.
```py
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch
# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()
# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()
# stage 3
safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()
prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
generator = torch.manual_seed(0)
# stage 1
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")
# stage 2
image = stage_2(
image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
pt_to_pil(image)[0].save("./if_stage_II.png")
# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("./if_stage_III.png")
```
There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To do so, please have a look at the Diffusers docs:
- 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed)
- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory)
For more in-detail information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖.
Diffusers dreambooth scripts also supports fine-tuning 🎨 [IF](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#if).
With parameter efficient finetuning, you can add new concepts to IF with a single GPU and ~28 GB VRAM.
## Run the code locally
### Loading the models into VRAM
```python
from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII
from deepfloyd_if.modules.t5 import T5Embedder
device = 'cuda:0'
if_I = IFStageI('IF-I-XL-v1.0', device=device)
if_II = IFStageII('IF-II-L-v1.0', device=device)
if_III = StableStageIII('stable-diffusion-x4-upscaler', device=device)
t5 = T5Embedder(device="cpu")
```
### I. Dream
Dream is the text-to-image mode of the IF model
```python
from deepfloyd_if.pipelines import dream
prompt = 'ultra close-up color photo portrait of rainbow owl with deer horns in the woods'
count = 4
result = dream(
t5=t5, if_I=if_I, if_II=if_II, if_III=if_III,
prompt=[prompt]*count,
seed=42,
if_I_kwargs={
"guidance_scale": 7.0,
"sample_timestep_respacing": "smart100",
},
if_II_kwargs={
"guidance_scale": 4.0,
"sample_timestep_respacing": "smart50",
},
if_III_kwargs={
"guidance_scale": 9.0,
"noise_level": 20,
"sample_timestep_respacing": "75",
},
)
if_III.show(result['III'], size=14)
```

## II. Zero-shot Image-to-Image Translation

In Style Transfer mode, the output of your prompt comes out at the style of the `support_pil_img`
```python
from deepfloyd_if.pipelines import style_transfer
result = style_transfer(
t5=t5, if_I=if_I, if_II=if_II,
support_pil_img=raw_pil_image,
style_prompt=[
'in style of professional origami',
'in style of oil art, Tate modern',
'in style of plastic building bricks',
'in style of classic anime from 1990',
],
seed=42,
if_I_kwargs={
"guidance_scale": 10.0,
"sample_timestep_respacing": "10,10,10,10,10,10,10,10,0,0",
'support_noise_less_qsample_steps': 5,
},
if_II_kwargs={
"guidance_scale": 4.0,
"sample_timestep_respacing": 'smart50',
"support_noise_less_qsample_steps": 5,
},
)
if_I.show(result['II'], 1, 20)
```

## III. Super Resolution
For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated by IF (two cascades):
```python
from deepfloyd_if.pipelines import super_resolution
middle_res = super_resolution(
t5,
if_III=if_II,
prompt=['woman with a blue headscarf and a blue sweaterp, detailed picture, 4k dslr, best quality'],
support_pil_img=raw_pil_image,
img_scale=4.,
img_size=64,
if_III_kwargs={
'sample_timestep_respacing': 'smart100',
'aug_level': 0.5,
'guidance_scale': 6.0,
},
)
high_res = super_resolution(
t5,
if_III=if_III,
prompt=[''],
support_pil_img=middle_res['III'][0],
img_scale=4.,
img_size=256,
if_III_kwargs={
"guidance_scale": 9.0,
"noise_level": 20,
"sample_timestep_respacing": "75",
},
)
show_superres(raw_pil_image, high_res['III'][0])
```

### IV. Zero-shot Inpainting
```python
from deepfloyd_if.pipelines import inpainting
result = inpainting(
t5=t5, if_I=if_I,
if_II=if_II,
if_III=if_III,
support_pil_img=raw_pil_image,
inpainting_mask=inpainting_mask,
prompt=[
'oil art, a man in a hat',
],
seed=42,
if_I_kwargs={
"guidance_scale": 7.0,
"sample_timestep_respacing": "10,10,10,10,10,0,0,0,0,0",
'support_noise_less_qsample_steps': 0,
},
if_II_kwargs={
"guidance_scale": 4.0,
'aug_level': 0.0,
"sample_timestep_respacing": '100',
},
if_III_kwargs={
"guidance_scale": 9.0,
"noise_level": 20,
"sample_timestep_respacing": "75",
},
)
if_I.show(result['I'], 2, 3)
if_I.show(result['II'], 2, 6)
if_I.show(result['III'], 2, 14)
```

### 🤗 Model Zoo 🤗
The link to download the weights as well as the model cards will be available soon on each model of the model zoo
#### Original
| Name | Cascade | Params | FID | Batch size | Steps |
|:----------------------------------------------------------|:-------:|:------:|:----:|:----------:|:-----:|
| [IF-I-M](https://huggingface.co/DeepFloyd/IF-I-M-v1.0) | I | 400M | 8.86 | 3072 | 2.5M |
| [IF-I-L](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) | I | 900M | 8.06 | 3200 | 3.0M |
| [IF-I-XL](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)* | I | 4.3B | 6.66 | 3072 | 2.42M |
| [IF-II-M](https://huggingface.co/DeepFloyd/IF-II-M-v1.0) | II | 450M | - | 1536 | 2.5M |
| [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* | II | 1.2B | - | 1536 | 2.5M |
| IF-III-L* _(soon)_ | III | 700M | - | 3072 | 1.25M |
*best modules
### Quantitative Evaluation
`FID = 6.66`

## License
The code in this repository is released under the bespoke license (see added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)).
The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE.
**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.*
## Limitations and Biases
The models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information.
## 🎓 DeepFloyd IF creators:
- Alex Shonenkov [GitHub](https://github.com/shonenkov) | [Linktr](https://linktr.ee/shonenkovAI)
- Misha Konstantinov [GitHub](https://github.com/zeroshot-ai) | [Twitter](https://twitter.com/_bra_ket)
- Daria Bakshandaeva [GitHub](https://github.com/Gugutse) | [Twitter](https://twitter.com/_gugutse_)
- Christoph Schuhmann [GitHub](https://github.com/christophschuhmann) | [Twitter](https://twitter.com/laion_ai)
- Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv)
- Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls)
## 📄 Research Paper (Soon)
## Acknowledgements
Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory consumption during inference, creating demos and giving cool advice!
## 🚀 External Contributors 🚀
- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉;
- Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving loading time of unet models by 80%;
for integration Stable-Diffusion-x4 as native pipeline 💪;
- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌;
- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀;
- Thanks, [@Dango233](https://github.com/Dango233), for adapting IF with xformers memory efficient attention 💪;
================================================
FILE: deepfloyd_if/__init__.py
================================================
# -*- coding: utf-8 -*-
__version__ = '1.0.2rc0'
================================================
FILE: deepfloyd_if/model/__init__.py
================================================
# -*- coding: utf-8 -*-
from .unet import UNetModel, SuperResUNetModel
__all__ = ['UNetModel', 'SuperResUNetModel']
================================================
FILE: deepfloyd_if/model/gaussian_diffusion.py
================================================
# -*- coding: utf-8 -*-
"""
This code started out as a PyTorch port of Ho et al's diffusion model:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""
import enum
import math
import numpy as np
import torch
from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == 'linear':
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == 'cosine':
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f'unknown beta schedule: {schedule_name}')
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = (
enum.auto()
) # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion model.
Ported directly from here, and then adapted over time to further experimentation.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
:param model_mean_type: a ModelMeanType determining what the model outputs.
:param model_var_type: a ModelVarType determining how variance is output.
:param loss_type: a LossType determining the loss function to use.
:param rescale_timesteps: if True, pass floating point timesteps into the
model so that they are always scaled like in the
original paper (0 to 1000).
"""
def __init__(
self,
*,
betas,
model_mean_type,
model_var_type,
loss_type,
rescale_timesteps=False,
):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
self.rescale_timesteps = rescale_timesteps
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, 'betas must be 1-D'
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
def dynamic_thresholding(self, x, p=0.995, c=1.7):
"""
Dynamic thresholding, a diffusion sampling technique from Imagen (https://arxiv.org/abs/2205.11487)
to leverage high guidance weights and generating more photorealistic and detailed images
than previously was possible based on x.clamp(-1, 1) vanilla clipping or static thresholding
p — percentile determine relative value for clipping threshold for dynamic compression,
helps prevent oversaturation recommend values [0.96 — 0.99]
c — absolute hard clipping of value for clipping threshold for dynamic compression,
helps prevent undersaturation and low contrast issues; recommend values [1.5 — 2.]
"""
x_shapes = x.shape
s = torch.quantile(x.abs().reshape(x_shapes[0], -1), p, dim=-1)
s = torch.clamp(s, min=1, max=c)
x_compressed = torch.clip(x.reshape(x_shapes[0], -1).T, -s, s) / s
x_compressed = x_compressed.T.reshape(x_shapes)
return x_compressed
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = torch.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
denoised_fn=None, model_kwargs=None
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = torch.split(model_output, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values
model_variance = torch.exp(model_log_variance)
else:
min_log = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x.shape
)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = torch.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
x = self.dynamic_thresholding(x, p=dynamic_thresholding_p, c=dynamic_thresholding_c)
return x # x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
else:
raise NotImplementedError(self.model_mean_type)
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
'mean': model_mean,
'variance': model_variance,
'log_variance': model_log_variance,
'pred_xstart': pred_xstart,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_xstart_from_xprev(self, x_t, t, xprev):
assert x_t.shape == xprev.shape
return ( # (xprev - coef2*x_t) / coef1
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
- _extract_into_tensor(
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
)
* x_t
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * (1000.0 / self.num_timesteps)
return t
def p_sample(
self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
denoised_fn=None, model_kwargs=None, inpainting_mask=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = torch.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if inpainting_mask is None:
inpainting_mask = torch.ones_like(x, device=x.device)
sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise
sample = (1 - inpainting_mask)*x + inpainting_mask*sample
return {'sample': sample, 'pred_xstart': out['pred_xstart']}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
dynamic_thresholding_p=0.99,
dynamic_thresholding_c=1.7,
inpainting_mask=None,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
sample_fn=None,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for step_idx, sample in enumerate(self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
denoised_fn=denoised_fn,
inpainting_mask=inpainting_mask,
model_kwargs=model_kwargs,
device=device,
progress=progress,
)):
if sample_fn is not None:
sample = sample_fn(step_idx, sample)
final = sample
return final['sample']
def p_sample_loop_progressive(
self,
model,
shape,
inpainting_mask=None,
noise=None,
clip_denoised=True,
dynamic_thresholding_p=0.99,
dynamic_thresholding_c=1.7,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = torch.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = torch.tensor([i] * shape[0], device=device)
with torch.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
denoised_fn=denoised_fn,
inpainting_mask=inpainting_mask,
model_kwargs=model_kwargs,
)
yield out
img = out['sample']
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
dynamic_thresholding_p=0.99,
dynamic_thresholding_c=1.7,
denoised_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out['pred_xstart'])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = torch.randn_like(x)
mean_pred = (
out['pred_xstart'] * torch.sqrt(alpha_bar_prev)
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {'sample': sample, 'pred_xstart': out['pred_xstart']}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
dynamic_thresholding_p=0.99,
dynamic_thresholding_c=1.7,
denoised_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, 'Reverse ODE only for deterministic path'
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out['pred_xstart']
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
out['pred_xstart'] * torch.sqrt(alpha_bar_next)
+ torch.sqrt(1 - alpha_bar_next) * eps
)
return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
dynamic_thresholding_p=0.99,
dynamic_thresholding_c=1.7,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
sample_fn=None,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for step_idx, sample in enumerate(self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
)):
if sample_fn is not None:
sample = sample_fn(step_idx, sample)
final = sample
return final['sample']
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
dynamic_thresholding_p=0.99,
dynamic_thresholding_c=1.7,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = torch.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = torch.tensor([i] * shape[0], device=device)
with torch.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out['sample']
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out['mean'], out['log_variance']
)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out['mean'], log_scales=0.5 * out['log_variance']
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = torch.where((t == 0), decoder_nll, kl)
return {'output': output, 'pred_xstart': out['pred_xstart']}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = torch.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
terms['loss'] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)['output']
if self.loss_type == LossType.RESCALED_KL:
terms['loss'] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = torch.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
terms['vb'] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)['output']
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms['vb'] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms['mse'] = mean_flat((target - model_output) ** 2)
if 'vb' in terms:
terms['loss'] = terms['mse'] + terms['vb']
else:
terms['loss'] = terms['mse']
else:
raise NotImplementedError(self.loss_type)
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = torch.tensor([t] * batch_size, device=device)
noise = torch.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with torch.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out['output'])
xstart_mse.append(mean_flat((out['pred_xstart'] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out['pred_xstart'])
mse.append(mean_flat((eps - noise) ** 2))
vb = torch.stack(vb, dim=1)
xstart_mse = torch.stack(xstart_mse, dim=1)
mse = torch.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
'total_bpd': total_bpd,
'prior_bpd': prior_bpd,
'vb': vb,
'xstart_mse': xstart_mse,
'mse': mse,
}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
================================================
FILE: deepfloyd_if/model/losses.py
================================================
# -*- coding: utf-8 -*-
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion model codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""
import torch
import numpy as np
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < -0.999,
log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
================================================
FILE: deepfloyd_if/model/nn.py
================================================
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def gelu(x):
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
@torch.jit.script
def gelu_jit(x):
"""OpenAI's gelu implementation."""
return gelu(x)
class GELUJit(torch.nn.Module):
def forward(self, input: Tensor) -> Tensor:
return gelu_jit(input)
def get_activation(activation):
if activation == 'silu':
return torch.nn.SiLU()
elif activation == 'gelu_jit':
return GELUJit()
elif activation == 'gelu':
return torch.nn.GELU()
elif activation == 'none':
return torch.nn.Identity()
else:
raise ValueError(f'unknown activation type {activation}')
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
def forward(self, x):
y = super().forward(x).to(x.dtype)
return y
class AttentionPooling(nn.Module):
def __init__(self, num_heads, embed_dim, dtype=None):
super().__init__()
self.dtype = dtype
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
self.num_heads = num_heads
self.dim_per_head = embed_dim // self.num_heads
def forward(self, x):
bs, length, width = x.size()
def shape(x):
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs*self.num_heads, -1, self.dim_per_head)
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
x = x.transpose(1, 2)
return x
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
# (bs*n_heads, class_token_length, dim_per_head)
q = shape(self.q_proj(class_token))
# (bs*n_heads, length+class_token_length, dim_per_head)
k = shape(self.k_proj(x))
v = shape(self.v_proj(x))
# (bs*n_heads, class_token_length, length+class_token_length):
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
weight = torch.einsum(
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
# (bs*n_heads, dim_per_head, class_token_length)
a = torch.einsum('bts,bcs->bct', weight, v)
# (bs, length+1, width)
a = a.reshape(bs, -1, 1).transpose(1, 2)
return a[:, 0, :] # cls_token
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
def timestep_embedding(timesteps, dim, max_period=10000, dtype=None):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if dtype is None:
dtype = torch.float32
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device, dtype=dtype)
args = timesteps[:, None].type(dtype) * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def attention(q, k, v, d_k):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, v)
return output
================================================
FILE: deepfloyd_if/model/resample.py
================================================
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
import torch
import numpy as np
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = torch.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = torch.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, num_timesteps):
self._weights = np.ones([num_timesteps])
def weights(self):
return self._weights
class StaticSampler(ABC):
def sample(self, batch_size, device, static_step=100):
indices_np = np.ones(batch_size, dtype=np.int) * static_step
weights_np = np.ones(batch_size, dtype=np.int)
indices = torch.from_numpy(indices_np).long().to(device)
weights = torch.from_numpy(weights_np).float().to(device)
return indices, weights
================================================
FILE: deepfloyd_if/model/respace.py
================================================
# -*- coding: utf-8 -*-
import torch
import numpy as np
from . import gaussian_diffusion as gd
def create_gaussian_diffusion(
*,
steps=1000,
learn_sigma=False,
sigma_small=False,
noise_schedule='linear',
use_kl=False,
predict_xstart=False,
rescale_timesteps=False,
rescale_learned_sigmas=False,
timestep_respacing='',
):
betas = gd.get_named_beta_schedule(noise_schedule, steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if not timestep_respacing:
timestep_respacing = [steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
rescale_timesteps=rescale_timesteps,
)
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith('ddim'):
desired_count = int(section_counts[len('ddim'):])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f'cannot create exactly {num_timesteps} steps with an integer stride'
)
elif section_counts == 'fast27':
steps = space_timesteps(num_timesteps, '10,10,3,2,2')
# Help reduce DDIM artifacts from noisiest timesteps.
steps.remove(num_timesteps - 1)
steps.add(num_timesteps - 3)
return steps
section_counts = [int(x) for x in section_counts.split(',')]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f'cannot divide section of {size} steps into {section_count}'
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(gd.GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs['betas'])
base_diffusion = gd.GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs['betas'] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
================================================
FILE: deepfloyd_if/model/unet.py
================================================
# -*- coding: utf-8 -*-
import os
import math
from abc import abstractmethod
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \
AttentionPooling
_FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0))
print('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION')
if _FORCE_MEM_EFFICIENT_ATTN:
from xformers.ops import memory_efficient_attention # noqa
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, encoder_out=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining a convolution is applied.
:param dims: determines the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.dtype = dtype
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
else:
if self.dtype == torch.bfloat16:
x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16)
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.dtype == torch.bfloat16:
x = x.type(torch.bfloat16)
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining a convolution is applied.
:param dims: determines the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.dtype = dtype
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: specified, the number of out channels.
:param use_conv: True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines the signal is 1D, 2D, or 3D.
:param up: True, use this block for upsampling.
:param down: True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
activation,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
dtype=None,
efficient_activation=False,
scale_skip_connection=False,
):
super().__init__()
self.dtype = dtype
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.efficient_activation = efficient_activation
self.scale_skip_connection = scale_skip_connection
self.in_layers = nn.Sequential(
normalization(channels, dtype=self.dtype),
get_activation(activation),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims, dtype=self.dtype)
self.x_upd = Upsample(channels, False, dims, dtype=self.dtype)
elif down:
self.h_upd = Downsample(channels, False, dims, dtype=self.dtype)
self.x_upd = Downsample(channels, False, dims, dtype=self.dtype)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.Identity() if self.efficient_activation else get_activation(activation),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
dtype=self.dtype
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, dtype=self.dtype),
get_activation(activation),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
res = self.skip_connection(x) + h
if self.scale_skip_connection:
res *= 0.7071 # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf
return res
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
disable_self_attention=False,
encoder_channels=None,
dtype=None,
):
super().__init__()
self.dtype = dtype
self.channels = channels
self.disable_self_attention = disable_self_attention
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.norm = normalization(channels, dtype=self.dtype)
self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)
if self.disable_self_attention:
self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype)
else:
self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)
self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention)
if encoder_channels is not None:
self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype)
self.norm_encoder = normalization(encoder_channels, dtype=self.dtype)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype))
def forward(self, x, encoder_out=None):
b, c, *spatial = x.shape
qkv = self.qkv(self.norm(x).view(b, c, -1))
if encoder_out is not None:
# from imagen article: https://arxiv.org/pdf/2205.11487.abs
encoder_out = self.norm_encoder(encoder_out)
# # #
encoder_out = self.encoder_kv(encoder_out)
h = self.attention(qkv, encoder_out)
else:
h = self.attention(qkv)
h = self.proj_out(h)
return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads, disable_self_attention=False):
super().__init__()
self.n_heads = n_heads
self.disable_self_attention = disable_self_attention
def forward(self, qkv, encoder_kv=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
if self.disable_self_attention:
ch = width // (1 * self.n_heads)
q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1)
else:
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_kv is not None:
assert encoder_kv.shape[1] == self.n_heads * ch * 2
if self.disable_self_attention:
k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
else:
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
if _FORCE_MEM_EFFICIENT_ATTN:
q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
a = memory_efficient_attention(q, k, v)
a = a.permute(0, 2, 1)
else:
weight = torch.einsum(
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: True, use learned convolutions for upsampling and
downsampling.
:param dims: determines the signal is 1D, 2D, or 3D.
:param num_classes: specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
activation,
encoder_dim,
att_pool_heads,
encoder_channels,
image_size,
disable_self_attentions=None,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
precision='32',
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
efficient_activation=False,
scale_skip_connection=False,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.encoder_channels = encoder_channels
self.encoder_dim = encoder_dim
self.efficient_activation = efficient_activation
self.scale_skip_connection = scale_skip_connection
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.dropout = dropout
# adapt attention resolutions
if isinstance(attention_resolutions, str):
self.attention_resolutions = []
for res in attention_resolutions.split(','):
self.attention_resolutions.append(image_size // int(res))
else:
self.attention_resolutions = attention_resolutions
self.attention_resolutions = tuple(self.attention_resolutions)
#
# adapt disable self attention resolutions
if not disable_self_attentions:
self.disable_self_attentions = []
elif disable_self_attentions is True:
self.disable_self_attentions = attention_resolutions
elif isinstance(disable_self_attentions, str):
self.disable_self_attentions = []
for res in disable_self_attentions.split(','):
self.disable_self_attentions.append(image_size // int(res))
else:
self.disable_self_attentions = disable_self_attentions
self.disable_self_attentions = tuple(self.disable_self_attentions)
#
# adapt channel mult
if isinstance(channel_mult, str):
self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(','))
else:
self.channel_mult = tuple(channel_mult)
#
self.conv_resample = conv_resample
self.num_classes = num_classes
self.dtype = torch.float32
self.precision = str(precision)
self.use_fp16 = precision == '16'
if self.precision == '16':
self.dtype = torch.float16
elif self.precision == 'bf16':
self.dtype = torch.bfloat16
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.time_embed_dim = model_channels * max(self.channel_mult)
self.time_embed = nn.Sequential(
linear(model_channels, self.time_embed_dim, dtype=self.dtype),
get_activation(activation),
linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, self.time_embed_dim)
ch = input_ch = int(self.channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))]
)
self._feature_size = ch
input_block_chans = [ch]
ds = 1
if isinstance(num_res_blocks, int):
num_res_blocks = [num_res_blocks]*len(self.channel_mult)
self.num_res_blocks = num_res_blocks
for level, mult in enumerate(self.channel_mult):
for _ in range(num_res_blocks[level]):
layers = [
ResBlock(
ch,
self.time_embed_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
activation=activation,
efficient_activation=self.efficient_activation,
scale_skip_connection=self.scale_skip_connection,
)
]
ch = int(mult * model_channels)
if ds in self.attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
dtype=self.dtype,
disable_self_attention=ds in self.disable_self_attentions,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(self.channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
self.time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dtype=self.dtype,
activation=activation,
efficient_activation=self.efficient_activation,
scale_skip_connection=self.scale_skip_connection,
)
if resblock_updown
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
self.time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
activation=activation,
efficient_activation=self.efficient_activation,
scale_skip_connection=self.scale_skip_connection,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
dtype=self.dtype,
disable_self_attention=ds in self.disable_self_attentions,
),
ResBlock(
ch,
self.time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
activation=activation,
efficient_activation=self.efficient_activation,
scale_skip_connection=self.scale_skip_connection,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(self.channel_mult))[::-1]:
for i in range(num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
self.time_embed_dim,
dropout,
out_channels=int(model_channels * mult),
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
activation=activation,
efficient_activation=self.efficient_activation,
scale_skip_connection=self.scale_skip_connection,
)
]
ch = int(model_channels * mult)
if ds in self.attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
dtype=self.dtype,
disable_self_attention=ds in self.disable_self_attentions,
)
)
if level and i == num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
self.time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
dtype=self.dtype,
activation=activation,
efficient_activation=self.efficient_activation,
scale_skip_connection=self.scale_skip_connection,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch, dtype=self.dtype),
get_activation(activation),
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)),
)
self.activation_layer = get_activation(activation) if self.efficient_activation else nn.Identity()
self.encoder_pooling = nn.Sequential(
nn.LayerNorm(encoder_dim, dtype=self.dtype),
AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype),
nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype),
nn.LayerNorm(self.time_embed_dim, dtype=self.dtype)
)
if encoder_dim != encoder_channels:
self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype)
else:
self.encoder_proj = nn.Identity()
self.cache = None
def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, dtype=self.dtype))
if use_cache and self.cache is not None:
encoder_out, encoder_pool = self.cache
else:
text_emb = text_emb.type(self.dtype)
encoder_out = self.encoder_proj(text_emb)
encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL
if timestep_text_emb is None:
timestep_text_emb = text_emb
encoder_pool = self.encoder_pooling(timestep_text_emb)
if use_cache:
self.cache = (encoder_out, encoder_pool)
emb = emb + encoder_pool.to(emb)
if aug_emb is not None:
emb = emb + aug_emb.to(emb)
emb = self.activation_layer(emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, encoder_out)
hs.append(h)
h = self.middle_block(h, emb, encoder_out)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, encoder_out)
h = h.type(self.dtype)
h = self.out(h)
return h
class SuperResUNetModel(UNetModel):
"""
A text2im model that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs):
self.low_res_diffusion = low_res_diffusion
self.interpolate_mode = interpolate_mode
super().__init__(*args, **kwargs)
self.aug_proj = nn.Sequential(
linear(self.model_channels, self.time_embed_dim, dtype=self.dtype),
get_activation(kwargs['activation']),
linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
)
def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
bs, _, new_height, new_width = x.shape
align_corners = True
if self.interpolate_mode == 'nearest':
align_corners = None
upsampled = F.interpolate(
low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners
)
if aug_level is None:
aug_steps = (np.random.random(bs)*1000).astype(np.int64) # uniform [0, 1)
aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long)
else:
aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long)
upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps)
x = torch.cat([x, upsampled], dim=1)
aug_emb = self.aug_proj(
timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype)
)
return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs)
================================================
FILE: deepfloyd_if/modules/__init__.py
================================================
# -*- coding: utf-8 -*-
from .stage_I import IFStageI
from .stage_II import IFStageII
from .stage_III import IFStageIII
from .stage_III_sd_x4 import StableStageIII
from .t5 import T5Embedder
from .base import IFBaseModule
__all__ = ['IFBaseModule', 'IFStageI', 'IFStageII', 'IFStageIII', 'StableStageIII', 'T5Embedder']
================================================
FILE: deepfloyd_if/modules/base.py
================================================
# -*- coding: utf-8 -*-
import os
import random
import platform
from datetime import datetime
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device
from .. import utils
from ..model.respace import create_gaussian_diffusion
from .utils import load_model_weights, predict_proba, clip_process_generations
class IFBaseModule:
stage = '-'
available_models = []
cpu_zero_emb = np.load(os.path.join(utils.RESOURCES_ROOT, 'zero_t5-v1_1-xxl_vector.npy'))
cpu_zero_emb = torch.from_numpy(cpu_zero_emb)
respacing_modes = {
'fast27': '10,10,3,2,2',
'smart27': '7,4,2,1,2,4,7',
'smart50': '10,6,4,3,2,2,3,4,6,10',
'smart100': '1,1,1,1,2,2,2,2,2,2,3,3,4,4,5,5,6,7,7,8,9,10,13',
'smart185': '1,1,2,2,2,3,3,3,4,5,6,7,8,9,10,11,12,13,14,15,16,18,20',
'super27': '1,1,1,1,1,1,1,2,5,13', # for III super-res
'super40': '2,2,2,2,2,2,3,4,6,15', # for III super-res
'super100': '4,4,6,6,8,8,10,10,14,30', # for III super-res
}
wm_pil_img = Image.open(os.path.join(utils.RESOURCES_ROOT, 'wm.png'))
try:
import clip # noqa
except ModuleNotFoundError:
print('Warning! You should install CLIP: "pip install git+https://github.com/openai/CLIP.git --no-deps"')
raise
clip_model, clip_preprocess = clip.load('ViT-L/14', device='cpu')
clip_model.eval()
cpu_w_weights, cpu_w_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'w_head_v1.npz'))
cpu_p_weights, cpu_p_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'p_head_v1.npz'))
w_threshold, p_threshold = 0.5, 0.5
def __init__(self, dir_or_name, device, pil_img_size=256, cache_dir=None, hf_token=None):
self.hf_token = hf_token
self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
self.dir_or_name = dir_or_name
self.conf = self.load_conf(dir_or_name) if not self.use_diffusers else None
self.device = torch.device(device)
self.zero_emb = self.cpu_zero_emb.clone().to(self.device)
self.pil_img_size = pil_img_size
@property
def use_diffusers(self):
return False
def embeddings_to_image(
self, t5_embs, low_res=None, *,
style_t5_embs=None,
positive_t5_embs=None,
negative_t5_embs=None,
batch_repeat=1,
dynamic_thresholding_p=0.95,
sample_loop='ddpm',
sample_timestep_respacing='smart185',
dynamic_thresholding_c=1.5,
guidance_scale=7.0,
aug_level=0.25,
positive_mixer=0.15,
blur_sigma=None,
img_size=None,
img_scale=4.0,
aspect_ratio='1:1',
progress=True,
seed=None,
sample_fn=None,
support_noise=None,
support_noise_less_qsample_steps=0,
inpainting_mask=None,
**kwargs,
):
self._clear_cache()
image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale)
diffusion = self.get_diffusion(sample_timestep_respacing)
bs_scale = 2 if positive_t5_embs is None else 3
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // bs_scale]
combined = torch.cat([half]*bs_scale, dim=0)
model_out = self.model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
if bs_scale == 3:
cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
half_eps = uncond_eps + guidance_scale * (
cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps)
pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps)
eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0)
else:
cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
seed = self.seed_everything(seed)
text_emb = t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)
batch_size = text_emb.shape[0] * batch_repeat
if positive_t5_embs is not None:
positive_t5_embs = positive_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)
if negative_t5_embs is not None:
negative_t5_embs = negative_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)
timestep_text_emb = None
if style_t5_embs is not None:
list_timestep_text_emb = [
style_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1),
]
if positive_t5_embs is not None:
list_timestep_text_emb.append(positive_t5_embs)
if negative_t5_embs is not None:
list_timestep_text_emb.append(negative_t5_embs)
else:
list_timestep_text_emb.append(
self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))
timestep_text_emb = torch.cat(list_timestep_text_emb, dim=0).to(self.device, dtype=self.model.dtype)
metadata = {
'seed': seed,
'guidance_scale': guidance_scale,
'dynamic_thresholding_p': dynamic_thresholding_p,
'dynamic_thresholding_c': dynamic_thresholding_c,
'batch_size': batch_size,
'device_name': self.device_name,
'img_size': [image_w, image_h],
'sample_loop': sample_loop,
'sample_timestep_respacing': sample_timestep_respacing,
'stage': self.stage,
}
list_text_emb = [t5_embs.to(self.device)]
if positive_t5_embs is not None:
list_text_emb.append(positive_t5_embs.to(self.device))
if negative_t5_embs is not None:
list_text_emb.append(negative_t5_embs.to(self.device))
else:
list_text_emb.append(
self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))
model_kwargs = dict(
text_emb=torch.cat(list_text_emb, dim=0).to(self.device, dtype=self.model.dtype),
timestep_text_emb=timestep_text_emb,
use_cache=True,
)
if low_res is not None:
if blur_sigma is not None:
low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res)
model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device)
model_kwargs['aug_level'] = aug_level
if support_noise is None:
noise = torch.randn(
(batch_size * bs_scale, 3, image_h, image_w), device=self.device, dtype=self.model.dtype)
else:
assert support_noise_less_qsample_steps < len(diffusion.timestep_map) - 1
assert support_noise.shape == (1, 3, image_h, image_w)
q_sample_steps = torch.tensor([int(len(diffusion.timestep_map) - 1 - support_noise_less_qsample_steps)])
support_noise = support_noise.cpu()
noise = support_noise.clone()
noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...] = diffusion.q_sample(
support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...],
q_sample_steps,
)
noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype)
if inpainting_mask is not None:
inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long)
if sample_loop == 'ddpm':
with torch.no_grad():
sample = diffusion.p_sample_loop(
model_fn,
(batch_size * bs_scale, 3, image_h, image_w),
noise=noise,
clip_denoised=True,
model_kwargs=model_kwargs,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
inpainting_mask=inpainting_mask,
device=self.device,
progress=progress,
sample_fn=sample_fn,
)[:batch_size]
elif sample_loop == 'ddim':
with torch.no_grad():
sample = diffusion.ddim_sample_loop(
model_fn,
(batch_size * bs_scale, 3, image_h, image_w),
noise=noise,
clip_denoised=True,
model_kwargs=model_kwargs,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
device=self.device,
progress=progress,
sample_fn=sample_fn,
)[:batch_size]
else:
raise ValueError(f'Sample loop "{sample_loop}" doesnt support')
sample = self.__validate_generations(sample)
self._clear_cache()
return sample, metadata
def load_conf(self, dir_or_name, filename='config.yml'):
path = self._get_path_or_download_file_from_hf(dir_or_name, filename)
conf = OmegaConf.load(path)
return conf
def load_checkpoint(self, model, dir_or_name, filename='pytorch_model.bin'):
path = self._get_path_or_download_file_from_hf(dir_or_name, filename)
if os.path.exists(path):
checkpoint = torch.load(path, map_location='cpu')
param_device = 'cpu'
for param_name, param in checkpoint.items():
set_module_tensor_to_device(model, param_name, param_device, value=param)
else:
print(f'Warning! In directory "{dir_or_name}" filename "pytorch_model.bin" is not found.')
return model
def _get_path_or_download_file_from_hf(self, dir_or_name, filename):
if dir_or_name in self.available_models:
cache_dir = os.path.join(self.cache_dir, dir_or_name)
hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
force_filename=filename, token=self.hf_token)
return os.path.join(cache_dir, filename)
else:
return os.path.join(dir_or_name, filename)
def get_diffusion(self, timestep_respacing):
timestep_respacing = self.respacing_modes.get(timestep_respacing, timestep_respacing)
diffusion = create_gaussian_diffusion(
steps=1000,
learn_sigma=True,
sigma_small=False,
noise_schedule='cosine',
use_kl=False,
predict_xstart=False,
rescale_timesteps=True,
rescale_learned_sigmas=True,
timestep_respacing=timestep_respacing,
)
return diffusion
@staticmethod
def seed_everything(seed=None):
if seed is None:
seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
return seed
def device_name(self):
if self.device.type == 'cpu':
return 'cpu_' + str(platform.processor())
if self.device.type == 'cuda':
return torch.cuda.get_device_name(self.device)
return '-'
def to_images(self, generations, disable_watermark=False):
bs, c, h, w = generations.shape
coef = min(h / self.pil_img_size, w / self.pil_img_size)
img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)
S1, S2 = 1024 ** 2, img_w * img_h
K = (S2 / S1) ** 0.5
wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)
wm_img = self.wm_pil_img.resize(
(wm_size, wm_size), getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None)
pil_images = []
for image in ((generations + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu():
pil_img = torchvision.transforms.functional.to_pil_image(image).convert('RGB')
pil_img = pil_img.resize((img_w, img_h), getattr(Image, 'Resampling', Image).NEAREST)
if not disable_watermark:
pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])
pil_images.append(pil_img)
return pil_images
def show(self, pil_images, nrow=None, size=10):
if nrow is None:
nrow = round(len(pil_images)**0.5)
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):
imgs = [imgs.cpu()]
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size))
for i, img in enumerate(imgs):
img = img.detach()
img = torchvision.transforms.functional.to_pil_image(img)
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
fix.show()
plt.show()
def _clear_cache(self):
self.model.cache = None
def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale):
if low_res is not None:
bs, c, h, w = low_res.shape
image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale//32))*32
else:
scale_w, scale_h = aspect_ratio.split(':')
scale_w, scale_h = int(scale_w), int(scale_h)
coef = scale_w / scale_h
image_h, image_w = img_size, img_size
if coef >= 1:
image_w = int(round(img_size/8 * coef) * 8)
else:
image_h = int(round(img_size/8 / coef) * 8)
assert image_h % 8 == 0
assert image_w % 8 == 0
return image_w, image_h
def __validate_generations(self, generations):
with torch.no_grad():
imgs = clip_process_generations(generations)
image_features = self.clip_model.encode_image(imgs.to('cpu'))
image_features = image_features.detach().cpu().numpy().astype(np.float16)
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
query = p_pred > self.p_threshold
if query.sum() > 0:
generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])
query = w_pred > self.w_threshold
if query.sum() > 0:
generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])
return generations
================================================
FILE: deepfloyd_if/modules/stage_I.py
================================================
# -*- coding: utf-8 -*-
import accelerate
from .base import IFBaseModule
from ..model import UNetModel
class IFStageI(IFBaseModule):
stage = 'I'
available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0']
def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
"""
:param conf_or_path:
:param device:
:param cache_dir:
:param use_auth_token:
"""
super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
model_params = dict(self.conf.params)
model_params.update(model_kwargs or {})
with accelerate.init_empty_weights():
self.model = UNetModel(**model_params)
self.model = self.load_checkpoint(self.model, self.dir_or_name)
self.model.eval().to(self.device)
def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,
sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,
aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):
return super().embeddings_to_image(
t5_embs=t5_embs,
style_t5_embs=style_t5_embs,
positive_t5_embs=positive_t5_embs,
negative_t5_embs=negative_t5_embs,
batch_repeat=batch_repeat,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
sample_loop=sample_loop,
sample_timestep_respacing=sample_timestep_respacing,
guidance_scale=guidance_scale,
img_size=64,
aspect_ratio=aspect_ratio,
progress=progress,
seed=seed,
sample_fn=sample_fn,
positive_mixer=positive_mixer,
**kwargs
)
================================================
FILE: deepfloyd_if/modules/stage_II.py
================================================
# -*- coding: utf-8 -*-
import accelerate
from .base import IFBaseModule
from ..model import SuperResUNetModel
class IFStageII(IFBaseModule):
stage = 'II'
available_models = ['IF-II-M-v1.0', 'IF-II-L-v1.0']
def __init__(self, *args, model_kwargs=None, pil_img_size=256, **kwargs):
super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
model_params = dict(self.conf.params)
model_params.update(model_kwargs or {})
with accelerate.init_empty_weights():
self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params)
self.model = self.load_checkpoint(self.model, self.dir_or_name)
self.model.eval().to(self.device)
def embeddings_to_image(
self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm',
sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5,
progress=True, seed=None, sample_fn=None, **kwargs):
return super().embeddings_to_image(
t5_embs=t5_embs,
low_res=low_res,
style_t5_embs=style_t5_embs,
positive_t5_embs=positive_t5_embs,
negative_t5_embs=negative_t5_embs,
batch_repeat=batch_repeat,
aug_level=aug_level,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
sample_loop=sample_loop,
sample_timestep_respacing=sample_timestep_respacing,
guidance_scale=guidance_scale,
positive_mixer=positive_mixer,
img_size=256,
img_scale=img_scale,
progress=progress,
seed=seed,
sample_fn=sample_fn,
**kwargs
)
================================================
FILE: deepfloyd_if/modules/stage_III.py
================================================
# -*- coding: utf-8 -*-
import accelerate
from .base import IFBaseModule
from ..model import SuperResUNetModel
class IFStageIII(IFBaseModule):
available_models = ['IF-III-L-v1.0']
def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
model_params = dict(self.conf.params)
model_params.update(model_kwargs or {})
with accelerate.init_empty_weights():
self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params)
self.model = self.load_checkpoint(self.model, self.dir_or_name)
self.model.eval().to(self.device)
def embeddings_to_image(
self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,
sample_loop='ddpm', sample_timestep_respacing='super40', guidance_scale=4.0, img_scale=4.0,
progress=True, seed=None, sample_fn=None, **kwargs):
return super().embeddings_to_image(
t5_embs=t5_embs,
low_res=low_res,
style_t5_embs=style_t5_embs,
positive_t5_embs=positive_t5_embs,
negative_t5_embs=negative_t5_embs,
batch_repeat=batch_repeat,
aug_level=aug_level,
blur_sigma=blur_sigma,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
sample_loop=sample_loop,
sample_timestep_respacing=sample_timestep_respacing,
guidance_scale=guidance_scale,
positive_mixer=positive_mixer,
img_size=1024,
img_scale=img_scale,
progress=progress,
seed=seed,
sample_fn=sample_fn,
**kwargs
)
================================================
FILE: deepfloyd_if/modules/stage_III_sd_x4.py
================================================
# -*- coding: utf-8 -*-
import diffusers
from diffusers import DiffusionPipeline, DDPMScheduler
import torch
import os
from .base import IFBaseModule
import packaging.version as pv
class StableStageIII(IFBaseModule):
available_models = ['stable-diffusion-x4-upscaler']
def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
if pv.parse(diffusers.__version__) <= pv.parse('0.15.1'):
raise ValueError(
'Make sure to have `diffusers >= 0.16.0` installed.'
' Please run `pip install diffusers --upgrade`'
)
model_id = os.path.join('stabilityai', self.dir_or_name)
model_kwargs = model_kwargs or {}
precision = str(model_kwargs.get('precision', '16'))
if precision == '16':
torch_dtype = torch.float16
elif precision == 'bf16':
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float32
self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token)
self.model.to(self.device)
if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')):
self.model.enable_xformers_memory_efficient_attention()
@property
def use_diffusers(self):
if self.dir_or_name == self.available_models[-1]:
return True
elif os.path.isdir(self.dir_or_name) and os.path.isfile(os.path.join(self.dir_or_name, 'model_index.json')):
return True
return False
def embeddings_to_image(
self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,
sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0,
progress=True, seed=None, sample_fn=None, **kwargs):
prompt = kwargs.pop('prompt')
noise_level = kwargs.pop('noise_level', 20)
if sample_loop == 'ddpm':
self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config)
else:
raise ValueError(f"For now only the 'ddpm' sample loop type is supported, but you passed {sample_loop}")
num_inference_steps = int(sample_timestep_respacing)
self.model.set_progress_bar_config(disable=not progress)
generator = torch.manual_seed(seed)
prompt = sum([batch_repeat * [p] for p in prompt], [])
low_res = low_res.repeat(batch_repeat, 1, 1, 1)
metadata = {
'image': low_res,
'prompt': prompt,
'noise_level': noise_level,
'generator': generator,
'guidance_scale': guidance_scale,
'num_inference_steps': num_inference_steps,
'output_type': 'pt',
}
images = self.model(**metadata).images
sample = self._IFBaseModule__validate_generations(images)
return sample, metadata
================================================
FILE: deepfloyd_if/modules/t5.py
================================================
# -*- coding: utf-8 -*-
import os
import re
import html
import urllib.parse as ul
import ftfy
import torch
from bs4 import BeautifulSoup
from transformers import T5EncoderModel, AutoTokenizer
from huggingface_hub import hf_hub_download
class T5Embedder:
available_models = ['t5-v1_1-xxl']
bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True,
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
if t5_model_kwargs is None:
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
if use_offload_folder is not None:
t5_model_kwargs['offload_folder'] = use_offload_folder
t5_model_kwargs['device_map'] = {
'shared': self.device,
'encoder.embed_tokens': self.device,
'encoder.block.0': self.device,
'encoder.block.1': self.device,
'encoder.block.2': self.device,
'encoder.block.3': self.device,
'encoder.block.4': self.device,
'encoder.block.5': self.device,
'encoder.block.6': self.device,
'encoder.block.7': self.device,
'encoder.block.8': self.device,
'encoder.block.9': self.device,
'encoder.block.10': self.device,
'encoder.block.11': self.device,
'encoder.block.12': 'disk',
'encoder.block.13': 'disk',
'encoder.block.14': 'disk',
'encoder.block.15': 'disk',
'encoder.block.16': 'disk',
'encoder.block.17': 'disk',
'encoder.block.18': 'disk',
'encoder.block.19': 'disk',
'encoder.block.20': 'disk',
'encoder.block.21': 'disk',
'encoder.block.22': 'disk',
'encoder.block.23': 'disk',
'encoder.final_layer_norm': 'disk',
'encoder.dropout': 'disk',
}
else:
t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
self.use_text_preprocessing = use_text_preprocessing
self.hf_token = hf_token
self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
self.dir_or_name = dir_or_name
tokenizer_path, path = dir_or_name, dir_or_name
if dir_or_name in self.available_models:
cache_dir = os.path.join(self.cache_dir, dir_or_name)
for filename in [
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
]:
hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
force_filename=filename, token=self.hf_token)
tokenizer_path, path = cache_dir, cache_dir
else:
cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
for filename in [
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
]:
hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
force_filename=filename, token=self.hf_token)
tokenizer_path = cache_dir
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
def get_text_embeddings(self, texts):
texts = [self.text_preprocessing(text) for text in texts]
text_tokens_and_mask = self.tokenizer(
texts,
max_length=77,
padding='max_length',
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors='pt'
)
text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
with torch.no_grad():
text_encoder_embs = self.model(
input_ids=text_tokens_and_mask['input_ids'].to(self.device),
attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
)['last_hidden_state'].detach()
return text_encoder_embs
def text_preprocessing(self, text):
if self.use_text_preprocessing:
# The exact text cleaning as was in the training stage:
text = self.clean_caption(text)
text = self.clean_caption(text)
return text
else:
return text.lower().strip()
@staticmethod
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub('', 'person', caption)
# urls:
caption = re.sub(
r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
'', caption) # regex for urls
caption = re.sub(
r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
'', caption) # regex for urls
# html:
caption = BeautifulSoup(caption, features='html.parser').text
# @
caption = re.sub(r'@[\w\d]+\b', '', caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
'-', caption)
# кавычки к одному стандарту
caption = re.sub(r'[`´«»“”¨]', '"', caption)
caption = re.sub(r'[‘’]', "'", caption)
# "
caption = re.sub(r'"?', '', caption)
# &
caption = re.sub(r'&', '', caption)
# ip adresses:
caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
# article ids:
caption = re.sub(r'\d:\d\d\s+$', '', caption)
# \n
caption = re.sub(r'\\n', ' ', caption)
# "#123"
caption = re.sub(r'#\d{1,3}\b', '', caption)
# "#12345.."
caption = re.sub(r'#\d{5,}\b', '', caption)
# "123456.."
caption = re.sub(r'\b\d{6,}\b', '', caption)
# filenames:
caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
#
caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r'(?:\-|\_)')
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, ' ', caption)
caption = self.basic_clean(caption)
caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
caption = re.sub(r'\bpage\s+\d+\b', '', caption)
caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
caption = re.sub(r'\b\s+\:\s+', r': ', caption)
caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
caption = re.sub(r'\s+', ' ', caption)
caption.strip()
caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
caption = re.sub(r'^\.\S+$', '', caption)
return caption.strip()
================================================
FILE: deepfloyd_if/modules/utils.py
================================================
# -*- coding: utf-8 -*-
import numpy as np
import torchvision.transforms as T
def predict_proba(X, weights, biases):
logits = X @ weights.T + biases
proba = np.where(logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)))
return proba.T
def load_model_weights(path):
model_weights = np.load(path)
return model_weights['weights'], model_weights['biases']
def clip_process_generations(generations):
min_size = min(generations.shape[-2:])
return T.Compose([
T.CenterCrop(min_size),
T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])(generations)
================================================
FILE: deepfloyd_if/pipelines/__init__.py
================================================
# -*- coding: utf-8 -*-
from .dream import dream
from .style_transfer import style_transfer
from .super_resolution import super_resolution
from .inpainting import inpainting
__all__ = ['dream', 'style_transfer', 'super_resolution', 'inpainting']
================================================
FILE: deepfloyd_if/pipelines/dream.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime
import torch
def dream(
t5,
if_I,
if_II=None,
if_III=None,
*,
prompt,
style_prompt=None,
negative_prompt=None,
seed=None,
aspect_ratio='1:1',
if_I_kwargs=None,
if_II_kwargs=None,
if_III_kwargs=None,
progress=True,
return_tensors=False,
disable_watermark=False,
):
"""
Generate pictures using text description!
:param optional dict if_I_kwargs:
"dynamic_thresholding_p": 0.95, [0.5, 1.0] it controls color saturation on high cfg values
"dynamic_thresholding_c": 1.5, [1.0, 15.0] clips the limiter to avoid greyish images on high limiter values
"guidance_scale": 7.0, [1.0, 20.0] control the level of text understanding
"positive_mixer": 0.25, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum
"sample_timestep_respacing": "150", see available modes IFBaseModule.respacing_modes or use custom
:param optional dict if_II_kwargs:
"dynamic_thresholding_p": 0.95, [0.5, 1.0] it controls color saturation on high cfg values
"dynamic_thresholding_c": 1.0, [1.0, 15.0] clips the limiter to avoid greyish images on high limiter values
"guidance_scale": 4.0, [1.0, 20.0] control the amount of texture and details in the final image
"aug_level": 0.25, [0.0, 1.0] adds additional augmentation to generate more realistic images
"positive_mixer": 0.5, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum
"sample_timestep_respacing": "smart50", see available modes IFBaseModule.respacing_modes or use custom
:param deepfloyd_if.modules.IFStageI if_I: obj
:param deepfloyd_if.modules.IFStageII if_II: obj
:param deepfloyd_if.modules.IFStageIII if_III: obj
:param deepfloyd_if.modules.T5Embedder t5: obj
:param int seed: int, in case None will use random value
:param aspect_ratio:
:param str prompt: text hint/description
:param str style_prompt: text hint/description for style
:param str negative_prompt: text hint/description for negative prompt, will use it as unconditional emb
:param progress:
:return:
"""
if seed is None:
seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
if_I.seed_everything(seed)
if isinstance(prompt, str):
prompt = [prompt]
t5_embs = t5.get_text_embeddings(prompt)
if_I_kwargs = if_I_kwargs or {}
if_I_kwargs['seed'] = seed
if_I_kwargs['t5_embs'] = t5_embs
if_I_kwargs['aspect_ratio'] = aspect_ratio
if_I_kwargs['progress'] = progress
if style_prompt is not None:
if isinstance(style_prompt, str):
style_prompt = [style_prompt]
style_t5_embs = t5.get_text_embeddings(style_prompt)
if_I_kwargs['style_t5_embs'] = style_t5_embs
if_I_kwargs['positive_t5_embs'] = style_t5_embs
if negative_prompt is not None:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
negative_t5_embs = t5.get_text_embeddings(negative_prompt)
if_I_kwargs['negative_t5_embs'] = negative_t5_embs
stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)
result = {'I': pil_images_I}
if if_II is not None:
if_II_kwargs = if_II_kwargs or {}
if_II_kwargs['low_res'] = stageI_generations
if_II_kwargs['seed'] = seed
if_II_kwargs['t5_embs'] = t5_embs
if_II_kwargs['progress'] = progress
if_II_kwargs['style_t5_embs'] = if_I_kwargs.get('style_t5_embs')
if_II_kwargs['positive_t5_embs'] = if_I_kwargs.get('positive_t5_embs')
stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)
result['II'] = pil_images_II
else:
stageII_generations = None
if if_II is not None and if_III is not None:
if_III_kwargs = if_III_kwargs or {}
stageIII_generations = []
for idx in range(len(stageII_generations)):
if if_III.use_diffusers:
if_III_kwargs['prompt'] = prompt[idx: idx+1]
if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
if_III_kwargs['seed'] = seed
if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
if_III_kwargs['progress'] = progress
style_t5_embs = if_I_kwargs.get('style_t5_embs')
if style_t5_embs is not None:
style_t5_embs = style_t5_embs[idx:idx+1]
positive_t5_embs = if_I_kwargs.get('positive_t5_embs')
if positive_t5_embs is not None:
positive_t5_embs = positive_t5_embs[idx:idx+1]
if_III_kwargs['style_t5_embs'] = style_t5_embs
if_III_kwargs['positive_t5_embs'] = positive_t5_embs
_stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
stageIII_generations.append(_stageIII_generations)
stageIII_generations = torch.cat(stageIII_generations, 0)
pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)
result['III'] = pil_images_III
else:
stageIII_generations = None
if return_tensors:
return result, (stageI_generations, stageII_generations, stageIII_generations)
else:
return result
================================================
FILE: deepfloyd_if/pipelines/inpainting.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime
import PIL
import torch
from .utils import _prepare_pil_image
def inpainting(
t5,
if_I,
if_II=None,
if_III=None,
*,
support_pil_img,
prompt,
inpainting_mask,
negative_prompt=None,
seed=None,
if_I_kwargs=None,
if_II_kwargs=None,
if_III_kwargs=None,
progress=True,
return_tensors=False,
disable_watermark=False,
):
from skimage.transform import resize # noqa
from skimage import img_as_bool # noqa
assert isinstance(support_pil_img, PIL.Image.Image)
if seed is None:
seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
t5_embs = t5.get_text_embeddings(prompt)
if negative_prompt is not None:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
negative_t5_embs = t5.get_text_embeddings(negative_prompt)
else:
negative_t5_embs = None
low_res = _prepare_pil_image(support_pil_img, 64)
mid_res = _prepare_pil_image(support_pil_img, 256)
high_res = _prepare_pil_image(support_pil_img, 1024)
result = {}
_, _, image_h, image_w = low_res.shape
if_I_kwargs = if_I_kwargs or {}
if_I_kwargs['seed'] = seed
if_I_kwargs['progress'] = progress
if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'
if_I_kwargs['t5_embs'] = t5_embs
if_I_kwargs['negative_t5_embs'] = negative_t5_embs
if_I_kwargs['support_noise'] = low_res
inpainting_mask_I = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))
inpainting_mask_I = torch.from_numpy(inpainting_mask_I).unsqueeze(0).to(if_I.device)
if_I_kwargs['inpainting_mask'] = inpainting_mask_I
stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)
result['I'] = pil_images_I
if if_II is not None:
_, _, image_h, image_w = mid_res.shape
if_II_kwargs = if_II_kwargs or {}
if_II_kwargs['low_res'] = stageI_generations
if_II_kwargs['seed'] = seed
if_II_kwargs['t5_embs'] = t5_embs
if_II_kwargs['negative_t5_embs'] = negative_t5_embs
if_II_kwargs['progress'] = progress
if_II_kwargs['support_noise'] = mid_res
if 'inpainting_mask' not in if_II_kwargs:
inpainting_mask_II = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))
inpainting_mask_II = torch.from_numpy(inpainting_mask_II).unsqueeze(0).to(if_II.device)
if_II_kwargs['inpainting_mask'] = inpainting_mask_II
stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)
result['II'] = pil_images_II
else:
stageII_generations = None
if if_II is not None and if_III is not None:
_, _, image_h, image_w = high_res.shape
if_III_kwargs = if_III_kwargs or {}
stageIII_generations = []
for idx in range(len(stageII_generations)):
if if_III.use_diffusers:
if_III_kwargs['prompt'] = prompt[idx: idx+1]
if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
if_III_kwargs['seed'] = seed
if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
if negative_t5_embs is not None:
if_III_kwargs['negative_t5_embs'] = negative_t5_embs[idx:idx+1]
if_III_kwargs['progress'] = progress
if_III_kwargs['support_noise'] = high_res
if 'inpainting_mask' not in if_III_kwargs:
inpainting_mask_III = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))
inpainting_mask_III = torch.from_numpy(inpainting_mask_III).unsqueeze(0).to(if_III.device)
if_III_kwargs['inpainting_mask'] = inpainting_mask_III
_stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
stageIII_generations.append(_stageIII_generations)
stageIII_generations = torch.cat(stageIII_generations, 0)
pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)
result['III'] = pil_images_III
else:
stageIII_generations = None
if return_tensors:
return result, (stageI_generations, stageII_generations, stageIII_generations)
else:
return result
================================================
FILE: deepfloyd_if/pipelines/style_transfer.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime
import PIL
import torch
from .utils import _prepare_pil_image
def style_transfer(
t5,
if_I,
if_II,
if_III=None,
*,
support_pil_img,
style_prompt,
prompt=None,
negative_prompt=None,
seed=None,
if_I_kwargs=None,
if_II_kwargs=None,
if_III_kwargs=None,
progress=True,
return_tensors=False,
disable_watermark=False,
):
assert isinstance(support_pil_img, PIL.Image.Image)
bs = len(style_prompt)
if seed is None:
seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
if prompt is not None:
t5_embs = t5.get_text_embeddings(prompt)
else:
t5_embs = t5.get_text_embeddings(style_prompt)
style_t5_embs = t5.get_text_embeddings(style_prompt)
if negative_prompt is not None:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
negative_t5_embs = t5.get_text_embeddings(negative_prompt)
else:
negative_t5_embs = None
low_res = _prepare_pil_image(support_pil_img, 64)
mid_res = _prepare_pil_image(support_pil_img, 256)
# high_res = _prepare_pil_image(support_pil_img, 1024)
result = {}
if if_I is not None:
_, _, image_h, image_w = low_res.shape
if_I_kwargs = if_I_kwargs or {'sample_timestep_respacing': '20,20,20,20,10,0,0,0,0,0'}
if_I_kwargs['seed'] = seed
if_I_kwargs['progress'] = progress
if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'
if_I_kwargs['t5_embs'] = t5_embs
if_I_kwargs['style_t5_embs'] = style_t5_embs
if_I_kwargs['positive_t5_embs'] = style_t5_embs
if_I_kwargs['negative_t5_embs'] = negative_t5_embs
if_I_kwargs['support_noise'] = low_res
stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)
result['I'] = pil_images_I
else:
stageI_generations = None
if if_II is not None:
if stageI_generations is None:
stageI_generations = low_res.repeat(bs, 1, 1, 1)
if_II_kwargs = if_II_kwargs or {}
if_II_kwargs['low_res'] = stageI_generations
if_II_kwargs['seed'] = seed
if_II_kwargs['t5_embs'] = t5_embs
if_II_kwargs['style_t5_embs'] = style_t5_embs
if_II_kwargs['positive_t5_embs'] = style_t5_embs
if_II_kwargs['negative_t5_embs'] = negative_t5_embs
if_II_kwargs['progress'] = progress
if_II_kwargs['support_noise'] = mid_res
stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)
result['II'] = pil_images_II
else:
stageII_generations = None
if if_II is not None and if_III is not None:
if_III_kwargs = if_III_kwargs or {}
stageIII_generations = []
for idx in range(len(stageII_generations)):
if if_III.use_diffusers:
if_III_kwargs['prompt'] = prompt[idx: idx+1] if prompt is not None else style_prompt[idx: idx+1]
if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
if_III_kwargs['seed'] = seed
if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
if_III_kwargs['progress'] = progress
style_t5_embs = if_II_kwargs.get('style_t5_embs')
if style_t5_embs is not None:
style_t5_embs = style_t5_embs[idx:idx+1]
positive_t5_embs = if_II_kwargs.get('positive_t5_embs')
if positive_t5_embs is not None:
positive_t5_embs = positive_t5_embs[idx:idx+1]
if_III_kwargs['style_t5_embs'] = style_t5_embs
if_III_kwargs['positive_t5_embs'] = positive_t5_embs
_stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
stageIII_generations.append(_stageIII_generations)
stageIII_generations = torch.cat(stageIII_generations, 0)
pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)
result['III'] = pil_images_III
else:
stageIII_generations = None
if return_tensors:
return result, (stageI_generations, stageII_generations, stageIII_generations)
else:
return result
================================================
FILE: deepfloyd_if/pipelines/super_resolution.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime
import PIL
from .utils import _prepare_pil_image
def super_resolution(
t5,
if_III=None,
*,
support_pil_img,
prompt=None,
negative_prompt=None,
seed=None,
if_III_kwargs=None,
progress=True,
img_size=256,
img_scale=4.0,
return_tensors=False,
disable_watermark=False,
):
assert isinstance(support_pil_img, PIL.Image.Image)
assert img_size % 8 == 0
if seed is None:
seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))
if prompt is not None:
t5_embs = t5.get_text_embeddings(prompt)
else:
t5_embs = t5.get_text_embeddings('')
if negative_prompt is not None:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
negative_t5_embs = t5.get_text_embeddings(negative_prompt)
else:
negative_t5_embs = None
low_res = _prepare_pil_image(support_pil_img, img_size)
result = {}
bs = 1
if_III_kwargs = if_III_kwargs or {}
if if_III.use_diffusers:
if_III_kwargs['prompt'] = prompt
if_III_kwargs['low_res'] = low_res.repeat(bs, 1, 1, 1)
if_III_kwargs['seed'] = seed
if_III_kwargs['t5_embs'] = t5_embs
if_III_kwargs['negative_t5_embs'] = negative_t5_embs
if_III_kwargs['progress'] = progress
if_III_kwargs['img_scale'] = img_scale
stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)
result['III'] = pil_images_III
if return_tensors:
return result, (stageIII_generations,)
else:
return result
================================================
FILE: deepfloyd_if/pipelines/utils.py
================================================
# -*- coding: utf-8 -*-
import torch
import numpy as np
from PIL import Image
def _prepare_pil_image(raw_pil_img, img_size):
raw_pil_img = raw_pil_img.convert('RGB')
w, h = raw_pil_img.size
coef = w / h
image_h, image_w = img_size, img_size
if coef >= 1:
image_w = int(round(img_size / 8 * coef) * 8)
else:
image_h = int(round(img_size / 8 / coef) * 8)
pil_img = raw_pil_img.resize(
(image_w, image_h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None
)
img = np.array(pil_img)
img = img.astype(np.float32) / 127.5 - 1
img = np.transpose(img, [2, 0, 1])
img = torch.from_numpy(img).unsqueeze(0)
return img
================================================
FILE: deepfloyd_if/utils.py
================================================
# -*- coding: utf-8 -*-
from os.path import abspath, dirname, join
import torch
import numpy as np
from PIL import Image, ImageFilter
RESOURCES_ROOT = join(abspath(dirname(__file__)), 'resources')
def drop_shadow(image, offset=(5, 5), background=0xffffff, shadow=0x444444, border=8, iterations=3):
"""
Drop shadows with PIL.
Author: Kevin Schluff
License: Python license
https://code.activestate.com/recipes/474116/
Add a gaussian blur drop shadow to an image.
image - The image to overlay on top of the shadow.
offset - Offset of the shadow from the image as an (x,y) tuple. Can be
positive or negative.
background - Background colour behind the image.
shadow - Shadow colour (darkness).
border - Width of the border around the image. This must be wide
enough to account for the blurring of the shadow.
iterations - Number of times to apply the filter. More iterations
produce a more blurred shadow, but increase processing time.
"""
# Create the backdrop image -- a box in the background colour with a
# shadow on it.
total_width = image.size[0] + abs(offset[0]) + 2 * border
total_height = image.size[1] + abs(offset[1]) + 2 * border
back = Image.new(image.mode, (total_width, total_height), background)
# Place the shadow, taking into account the offset from the image
shadow_left = border + max(offset[0], 0)
shadow_top = border + max(offset[1], 0)
back.paste(shadow, [shadow_left, shadow_top, shadow_left + image.size[0], shadow_top + image.size[1]])
# Apply the filter to blur the edges of the shadow. Since a small kernel
# is used, the filter must be applied repeatedly to get a decent blur.
n = 0
while n < iterations:
back = back.filter(ImageFilter.BLUR)
n += 1
# Paste the input image onto the shadow backdrop
image_left = border - min(offset[0], 0)
image_top = border - min(offset[1], 0)
back.paste(image, (image_left, image_top))
return back
def pil_list_to_torch_tensors(pil_images):
result = []
for pil_image in pil_images:
image = np.array(pil_image, dtype=np.uint8)
image = torch.from_numpy(image)
image = image.permute(2, 0, 1).unsqueeze(0)
result.append(image)
return torch.cat(result, dim=0)
================================================
FILE: requirements-dev.txt
================================================
-r requirements-test.txt
pre-commit
================================================
FILE: requirements-test.txt
================================================
-r requirements.txt
pytest
pytest-cov
================================================
FILE: requirements.txt
================================================
tqdm
numpy
torch<2.0.0
torchvision
omegaconf
matplotlib
Pillow>=9.2.0
huggingface_hub>=0.13.2
transformers~=4.25.1
accelerate~=0.15.0
diffusers~=0.16.0
tokenizers~=0.13.2
sentencepiece~=0.1.97
ftfy~=6.1.1
beautifulsoup4~=4.11.1
================================================
FILE: setup.cfg
================================================
[pep8]
max-line-length = 120
exclude = .tox,*migrations*,.json
[flake8]
max-line-length = 120
exclude = .tox,*migrations*,.json
[autopep8-wrapper]
exclude = .tox,*migrations*,.json
[check-docstring-first]
exclude = .tox,*migrations*,.json
================================================
FILE: setup.py
================================================
# -*- coding: utf-8 -*-
import os
import re
from setuptools import setup
def read(filename):
with open(os.path.join(os.path.dirname(__file__), filename)) as f:
file_content = f.read()
return file_content
def get_requirements():
requirements = []
for requirement in read('requirements.txt').splitlines():
if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):
parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
if parsed_requires:
package, version = parsed_requires[0]
requirements.append(f'{package}=={version}')
else:
print('WARNING! For correct matching dependency links need to specify package name and version'
'such as #egg=-')
else:
requirements.append(requirement)
return requirements
def get_links():
return [
requirement for requirement in read('requirements.txt').splitlines()
if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')
]
def get_version():
""" Get version from the package without actually importing it. """
init = read('deepfloyd_if/__init__.py')
for line in init.split('\n'):
if line.startswith('__version__'):
return eval(line.split('=')[1])
setup(
name='deepfloyd_if',
version=get_version(),
author='DeepFloyd, StabilityAI',
author_email='shonenkov@gmail.com',
description='DeepFloyd-IF (Imagen Free)',
packages=['deepfloyd_if', 'deepfloyd_if/model', 'deepfloyd_if/modules', 'deepfloyd_if/pipelines',
'deepfloyd_if/resources'],
package_data={'deepfloyd_if/resources': ['*.png', '*.npy', '*.npz']},
install_requires=get_requirements(),
dependency_links=get_links(),
long_description=read('README.md'),
long_description_content_type='text/markdown',
)