Repository: deep-floyd/IF
Branch: develop
Commit: ffc816389168
Files: 38
Total size: 165.4 KB

Directory structure:
gitextract_r39ejdyw/
├── .gitattributes
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── LICENSE
├── LICENSE-MODEL
├── README.md
├── deepfloyd_if/
│   ├── __init__.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── gaussian_diffusion.py
│   │   ├── losses.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── respace.py
│   │   └── unet.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── stage_I.py
│   │   ├── stage_II.py
│   │   ├── stage_III.py
│   │   ├── stage_III_sd_x4.py
│   │   ├── t5.py
│   │   └── utils.py
│   ├── pipelines/
│   │   ├── __init__.py
│   │   ├── dream.py
│   │   ├── inpainting.py
│   │   ├── style_transfer.py
│   │   ├── super_resolution.py
│   │   └── utils.py
│   ├── resources/
│   │   ├── p_head_v1.npz
│   │   ├── w_head_v1.npz
│   │   └── zero_t5-v1_1-xxl_vector.npy
│   └── utils.py
├── requirements-dev.txt
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
notebooks/pipes-DeepFloyd-IF.ipynb filter=lfs diff=lfs merge=lfs -text

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, deepfloyd_if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.2.0
    hooks:
      - id: check-docstring-first
      - id: check-merge-conflict
        stages:
          - push
      - id: double-quote-string-fixer
      - id: end-of-file-fixer
      - id: fix-encoding-pragma
      - id: mixed-line-ending
      - id: trailing-whitespace
  - repo: https://github.com/pycqa/flake8
    rev: "4.0.1"
    hooks:
      - id: flake8
        args: ['--config=setup.cfg']
  - repo: https://github.com/pre-commit/mirrors-autopep8
    rev: v1.6.0
    hooks:
      - id: autopep8

================================================
FILE: CHANGELOG.md
================================================
v1.0.2rc
-------

- uses separated tokenizer_path to init tokenizer in T5Embedder

v1.0.1
------

- renamed main model `IF-I-IF` --> `IF-I-XL`
- moved dir `notebooks` to HF storage https://huggingface.co/DeepFloyd/IF-notebooks; let's keep new notebooks there;
- added additional kaggle notebook (more free GPU resources) how to generate pictures 1k: [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)

v1.0.0
------

- initial version

================================================
FILE: LICENSE
================================================
Copyright (c) 2023 DeepFloyd, StabilityAI

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

1. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

2. All persons obtaining a copy or substantial portion of the Software, a modified version of the Software (or substantial portion thereof), or a derivative work based upon this Software (or substantial portion thereof) must not delete, remove, disable, diminish, or circumvent any inference filters or inference filter mechanisms in the Software, or any portion of the Software that implements any such filters or filter mechanisms.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: LICENSE-MODEL
================================================
DEEPFLOYD IF LICENSE AGREEMENT

This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd.. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”). By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.

1. LICENSE GRANT

a.
Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above. c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License. 2. RESTRICTIONS You will not, and will not permit, assist or cause any third party to: a. 
use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing; b. alter or remove copyright and other proprietary notices which appear on or in the Software Products; c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License. e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. 
government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods. 3. ATTRIBUTION Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “DeepFloyd is licensed under the DeepFloyd License, Copyright (c) Stability AI Ltd. All Rights Reserved.” 4. DISCLAIMERS THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. 5. LIMITATION OF LIABILITY TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 
THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. 6. INDEMNIFICATION You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. 
You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties. 7. TERMINATION; SURVIVAL a. This License will automatically terminate upon any breach by you of the terms of this License. b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous). 8. THIRD PARTY MATERIALS The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. 9. TRADEMARKS Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement. 10. APPLICABLE LAW; DISPUTE RESOLUTION This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. 
Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts. 11. MISCELLANEOUS If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI. 
================================================
FILE: README.md
================================================
[![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE)
[![License](https://img.shields.io/badge/Weights_License-DeepFloyd_IF-orange.svg)](LICENSE-MODEL)
[![Downloads](https://pepy.tech/badge/deepfloyd_if)](https://pepy.tech/project/deepfloyd_if)
[![Discord](https://img.shields.io/badge/Discord-%237289DA.svg?logo=discord&logoColor=white)](https://discord.gg/umz62Mgr)
[![Twitter](https://img.shields.io/badge/Twitter-%231DA1F2.svg?logo=twitter&logoColor=white)](https://twitter.com/deepfloydai)
[![Linktree](https://img.shields.io/badge/Linktree-%2339E09B.svg?logo=linktree&logoColor=white)](http://linktr.ee/deepfloyd)

# IF by [DeepFloyd Lab](https://deepfloyd.ai) at [StabilityAI](https://stability.ai/)

We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular model composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates a 64x64 px image from a text prompt, and two super-resolution models that generate images of increasing resolution, 256x256 px and 1024x1024 px. All stages of the model use a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and points to a promising future for text-to-image synthesis.
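To make the data flow of the three-stage cascade concrete, here is a toy NumPy sketch. Every function in it is a hypothetical stand-in, not part of `deepfloyd_if`: the "text encoder" returns a pseudo-random vector, and the "super-resolution" stages are plain nearest-neighbour 4x upsampling rather than diffusion models. Only the structure follows the paragraph above: one frozen text embedding computed once and shared by every stage, and outputs of 64, 256 and 1024 px.

```python
import numpy as np

# Hypothetical stand-ins only: none of these functions are the deepfloyd_if API.


def frozen_text_encoder(prompt: str, dim: int = 16) -> np.ndarray:
    """Stand-in for the frozen T5 encoder: a deterministic pseudo-embedding."""
    seed = sum(prompt.encode("utf-8"))
    return np.random.default_rng(seed).standard_normal(dim)


def base_stage(text_emb: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Stand-in for stage I: produces a 64x64 RGB array (text_emb unused here)."""
    return rng.standard_normal((64, 64, 3))


def super_res_stage(low_res: np.ndarray, text_emb: np.ndarray, scale: int = 4) -> np.ndarray:
    """Stand-in for stages II/III: nearest-neighbour upsampling by `scale`.

    The real modules are diffusion UNets conditioned on both the low-resolution
    image and the text embedding; this toy version ignores text_emb.
    """
    return low_res.repeat(scale, axis=0).repeat(scale, axis=1)


def run_cascade(prompt: str, seed: int = 0) -> list:
    rng = np.random.default_rng(seed)
    emb = frozen_text_encoder(prompt)         # computed once, shared by all stages
    img_64 = base_stage(emb, rng)             # stage I:   64x64
    img_256 = super_res_stage(img_64, emb)    # stage II:  256x256
    img_1024 = super_res_stage(img_256, emb)  # stage III: 1024x1024
    return [img_64, img_256, img_1024]


shapes = [img.shape for img in run_cascade("a rainbow owl")]
print(shapes)  # [(64, 64, 3), (256, 256, 3), (1024, 1024, 3)]
```

In the real pipeline each stand-in is a trained model; the sketch only tracks how shapes and the shared embedding move through the cascade.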

*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf)

## Minimum requirements to use all IF models:

- 16GB VRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module)
- 24GB VRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler)
- `xformers` installed and the environment variable `FORCE_MEM_EFFICIENT_ATTN=1` set

## Quick Start

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF)

```shell
pip install deepfloyd_if==1.0.2rc0
pip install xformers==0.0.16
pip install git+https://github.com/openai/CLIP.git --no-deps
```

## Local notebooks

[![Jupyter Notebook](https://img.shields.io/badge/jupyter_notebook-%23FF7A01.svg?logo=jupyter&logoColor=white)](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb)
[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)

The Dream, Style Transfer, Super Resolution, and Inpainting modes are available in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb).

## Integration with 🤗 Diffusers

IF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/).

Diffusers runs each stage individually, allowing users to customize the image generation process and to inspect intermediate results easily.

### Example

Before you can use IF, you need to accept its usage conditions. To do so:

1. Make sure you have a [Hugging Face account](https://huggingface.co/join) and are logged in.
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0).
3. Make sure to log in locally. Install `huggingface_hub`:

   ```sh
   pip install huggingface_hub --upgrade
   ```

   run the login function in a Python shell:

   ```py
   from huggingface_hub import login

   login()
   ```

   and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).

Next we install `diffusers` and its dependencies:

```sh
pip install diffusers accelerate transformers safetensors
```

We can now run the model locally.

The example below uses [model CPU offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM.

If you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()` calls.
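The `torch>=2.0.0` rule can also be checked in code instead of deleting lines by hand. This is a minimal sketch: `needs_xformers` is a hypothetical helper, not part of `diffusers`, and it assumes the version string starts with `major.minor`, as PyTorch version strings such as `"2.0.1+cu118"` do.

```python
import re


def needs_xformers(torch_version: str) -> bool:
    """Hypothetical helper: True if the torch version is older than 2.0.0.

    Following the README, pipelines running on torch>=2.0.0 should not call
    enable_xformers_memory_efficient_attention().
    """
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognised torch version string: {torch_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) < (2, 0)


print(needs_xformers("1.13.1"))       # True
print(needs_xformers("2.0.1+cu118"))  # False

# Usage with a pipeline such as stage_1 (requires torch and diffusers):
# import torch
# if needs_xformers(torch.__version__):
#     stage_1.enable_xformers_memory_efficient_attention()
```

Only the major and minor numbers are compared, which is all the 2.0.0 cut-off requires; local build suffixes like `+cu118` are ignored.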
```py
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch

# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()

prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'

# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

generator = torch.manual_seed(0)

# stage 1
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")

# stage 2
image = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("./if_stage_III.png")
```

There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`.
To do so, please have a look at the Diffusers docs:

- 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed)
- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory)

For more detailed information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖.

The Diffusers DreamBooth scripts also support fine-tuning 🎨 [IF](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#if). With parameter-efficient fine-tuning, you can add new concepts to IF with a single GPU and ~28 GB of VRAM.

## Run the code locally

### Loading the models into VRAM

```python
from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII
from deepfloyd_if.modules.t5 import T5Embedder

device = 'cuda:0'
if_I = IFStageI('IF-I-XL-v1.0', device=device)
if_II = IFStageII('IF-II-L-v1.0', device=device)
if_III = StableStageIII('stable-diffusion-x4-upscaler', device=device)
t5 = T5Embedder(device="cpu")
```

### I. Dream

Dream is the text-to-image mode of the IF model.

```python
from deepfloyd_if.pipelines import dream

prompt = 'ultra close-up color photo portrait of rainbow owl with deer horns in the woods'
count = 4

result = dream(
    t5=t5, if_I=if_I, if_II=if_II, if_III=if_III,
    prompt=[prompt]*count,
    seed=42,
    if_I_kwargs={
        "guidance_scale": 7.0,
        "sample_timestep_respacing": "smart100",
    },
    if_II_kwargs={
        "guidance_scale": 4.0,
        "sample_timestep_respacing": "smart50",
    },
    if_III_kwargs={
        "guidance_scale": 9.0,
        "noise_level": 20,
        "sample_timestep_respacing": "75",
    },
)

if_III.show(result['III'], size=14)
```

![](./pics/dream-III.jpg)

### II. Zero-shot Image-to-Image Translation

![](./pics/img_to_img_scheme.jpeg)

In Style Transfer mode, the content of `support_pil_img` is re-rendered in the style described by each entry of `style_prompt`.

```python
from deepfloyd_if.pipelines import style_transfer

# raw_pil_image: the input picture as a PIL.Image, supplied by you
result = style_transfer(
    t5=t5,
    if_I=if_I,
    if_II=if_II,
    support_pil_img=raw_pil_image,
    style_prompt=[
        'in style of professional origami',
        'in style of oil art, Tate modern',
        'in style of plastic building bricks',
        'in style of classic anime from 1990',
    ],
    seed=42,
    if_I_kwargs={
        "guidance_scale": 10.0,
        "sample_timestep_respacing": "10,10,10,10,10,10,10,10,0,0",
        'support_noise_less_qsample_steps': 5,
    },
    if_II_kwargs={
        "guidance_scale": 4.0,
        "sample_timestep_respacing": 'smart50',
        "support_noise_less_qsample_steps": 5,
    },
)
if_I.show(result['II'], 1, 20)
```

![Alternative Text](./pics/deep_floyd_if_image_2_image.gif)

### III. Super Resolution

For super-resolution, users can run `IF-II` and `IF-III` or `Stable x4` on an image that was not necessarily generated by IF (two cascades):

```python
from deepfloyd_if.pipelines import super_resolution

# raw_pil_image: the low-resolution input picture as a PIL.Image, supplied by you
middle_res = super_resolution(
    t5,
    if_III=if_II,
    prompt=['woman with a blue headscarf and a blue sweaterp, detailed picture, 4k dslr, best quality'],
    support_pil_img=raw_pil_image,
    img_scale=4.,
    img_size=64,
    if_III_kwargs={
        'sample_timestep_respacing': 'smart100',
        'aug_level': 0.5,
        'guidance_scale': 6.0,
    },
)
high_res = super_resolution(
    t5,
    if_III=if_III,
    prompt=[''],
    support_pil_img=middle_res['III'][0],
    img_scale=4.,
    img_size=256,
    if_III_kwargs={
        "guidance_scale": 9.0,
        "noise_level": 20,
        "sample_timestep_respacing": "75",
    },
)
show_superres(raw_pil_image, high_res['III'][0])
```

![](./pics/if_as_upscaler.jpg)

### IV. Zero-shot Inpainting

```python
from deepfloyd_if.pipelines import inpainting

# raw_pil_image and inpainting_mask are inputs you supply:
# the picture to edit and the mask marking the region to repaint
result = inpainting(
    t5=t5, if_I=if_I,
    if_II=if_II,
    if_III=if_III,
    support_pil_img=raw_pil_image,
    inpainting_mask=inpainting_mask,
    prompt=[
        'oil art, a man in a hat',
    ],
    seed=42,
    if_I_kwargs={
        "guidance_scale": 7.0,
        "sample_timestep_respacing": "10,10,10,10,10,0,0,0,0,0",
        'support_noise_less_qsample_steps': 0,
    },
    if_II_kwargs={
        "guidance_scale": 4.0,
        'aug_level': 0.0,
        "sample_timestep_respacing": '100',
    },
    if_III_kwargs={
        "guidance_scale": 9.0,
        "noise_level": 20,
        "sample_timestep_respacing": "75",
    },
)
if_I.show(result['I'], 2, 3)
if_I.show(result['II'], 2, 6)
if_I.show(result['III'], 2, 14)
```

![](./pics/deep_floyd_if_inpainting.gif)

### 🤗 Model Zoo 🤗

Links to download the weights, as well as the model cards, will be available soon for each model in the model zoo.

#### Original

| Name                                                      | Cascade | Params | FID  | Batch size | Steps |
|:----------------------------------------------------------|:-------:|:------:|:----:|:----------:|:-----:|
| [IF-I-M](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)    |    I    |  400M  | 8.86 |    3072    | 2.5M  |
| [IF-I-L](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)    |    I    |  900M  | 8.06 |    3200    | 3.0M  |
| [IF-I-XL](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)* |    I    |  4.3B  | 6.66 |    3072    | 2.42M |
| [IF-II-M](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)  |   II    |  450M  |  -   |    1536    | 2.5M  |
| [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* |   II    |  1.2B  |  -   |    1536    | 2.5M  |
| IF-III-L* _(soon)_                                        |   III   |  700M  |  -   |    3072    | 1.25M |

*best modules

### Quantitative Evaluation

`FID = 6.66`

![](./pics/fid30k_if.jpg)

## License

The code in this repository is released under a bespoke license (see the added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)).

The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE.
**Disclaimer:** *The initial release of the IF model is temporarily under a restricted, research-purposes-only license to gather feedback; after that, we intend to release a fully open-source model in line with other Stability AI models.*

## Limitations and Biases

The models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information.

## 🎓 DeepFloyd IF creators:

- Alex Shonenkov [GitHub](https://github.com/shonenkov) | [Linktr](https://linktr.ee/shonenkovAI)
- Misha Konstantinov [GitHub](https://github.com/zeroshot-ai) | [Twitter](https://twitter.com/_bra_ket)
- Daria Bakshandaeva [GitHub](https://github.com/Gugutse) | [Twitter](https://twitter.com/_gugutse_)
- Christoph Schuhmann [GitHub](https://github.com/christophschuhmann) | [Twitter](https://twitter.com/laion_ai)
- Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv)
- Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls)

## 📄 Research Paper (Soon)

## Acknowledgements

Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contributions to the project and well-prepared datasets; thanks to the [Hugging Face](https://huggingface.co) teams for optimizing the models' speed and memory consumption during inference, creating demos and giving cool advice!
## 🚀 External Contributors 🚀

- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉;
- Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving the loading time of unet models by 80%; for integrating Stable-Diffusion-x4 as a native pipeline 💪;
- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌;
- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀;
- Thanks, [@Dango233](https://github.com/Dango233), for adapting IF with xformers memory efficient attention 💪;

================================================
FILE: deepfloyd_if/__init__.py
================================================
# -*- coding: utf-8 -*-

__version__ = '1.0.2rc0'


================================================
FILE: deepfloyd_if/model/__init__.py
================================================
# -*- coding: utf-8 -*-
from .unet import UNetModel, SuperResUNetModel

__all__ = ['UNetModel', 'SuperResUNetModel']


================================================
FILE: deepfloyd_if/model/gaussian_diffusion.py
================================================
# -*- coding: utf-8 -*-
"""
This code started out as a PyTorch port of Ho et al's diffusion model:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import math

import numpy as np
import torch

from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == 'linear':
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == 'cosine':
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f'unknown beta schedule: {schedule_name}')


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion model.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, 'betas must be 1-D'
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def dynamic_thresholding(self, x, p=0.995, c=1.7):
        """
        Dynamic thresholding, a diffusion sampling technique from Imagen (https://arxiv.org/abs/2205.11487).
        It allows high guidance weights and produces more photorealistic and detailed images than
        vanilla x.clamp(-1, 1) clipping or static thresholding.

        p: the quantile of |x|, computed per sample, that sets the dynamic clipping threshold s;
           helps prevent oversaturation. Recommended values: [0.96, 0.99].

        c: the upper bound on s (s is clamped to [1, c]); helps prevent undersaturation and
           low-contrast issues. Recommended values: [1.5, 2.0].
        """
        x_shapes = x.shape
        s = torch.quantile(x.abs().reshape(x_shapes[0], -1), p, dim=-1)
        s = torch.clamp(s, min=1, max=c)
        x_compressed = torch.clip(x.reshape(x_shapes[0], -1).T, -s, s) / s
        x_compressed = x_compressed.T.reshape(x_shapes)
        return x_compressed

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
        denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, apply dynamic thresholding (see
            dynamic_thresholding()) to the denoised signal, which keeps it in [-1, 1].
        :param dynamic_thresholding_p: the `p` argument of dynamic_thresholding().
        :param dynamic_thresholding_c: the `c` argument of dynamic_thresholding().
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = torch.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = torch.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = torch.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                x = self.dynamic_thresholding(x, p=dynamic_thresholding_p, c=dynamic_thresholding_c)
                return x  # x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            'mean': model_mean,
            'variance': model_variance,
            'log_variance': model_log_variance,
            'pred_xstart': pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def p_sample(
        self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
        denoised_fn=None, model_kwargs=None, inpainting_mask=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor x_t at timestep t.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, apply dynamic thresholding to the x_start
            prediction, keeping it in [-1, 1].
        :param dynamic_thresholding_p, dynamic_thresholding_c: passed through to
            p_mean_variance().
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param inpainting_mask: if not None, a tensor shaped like x; where it is 1
            the new sample is used, and where it is 0 the input x is kept.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
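
Example (a simplified NumPy sketch of the same arithmetic, not this method's implementation): one ancestral step for an EPSILON-predicting model with FIXED_SMALL variance, followed by the inpainting blend. `p_sample_step` is a hypothetical helper; `fake_eps` stands in for a real network output and, like the toy 50-step schedule, is an assumption. Dynamic thresholding is omitted.

```python
import numpy as np


def p_sample_step(x, t, eps, betas, inpainting_mask=None, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    alphas = 1.0 - betas
    ab = np.cumprod(alphas)
    ab_prev = np.append(1.0, ab[:-1])
    # Posterior q(x_{t-1} | x_t, x_0) coefficients and variance, as in __init__.
    coef1 = betas * np.sqrt(ab_prev) / (1.0 - ab)
    coef2 = (1.0 - ab_prev) * np.sqrt(alphas) / (1.0 - ab)
    post_var = betas * (1.0 - ab_prev) / (1.0 - ab)
    post_logvar = np.log(np.append(post_var[1], post_var[1:]))
    # EPSILON parameterization: recover x_0, then take the posterior mean.
    x0_pred = np.sqrt(1.0 / ab[t]) * x - np.sqrt(1.0 / ab[t] - 1) * eps
    mean = coef1[t] * x0_pred + coef2[t] * x
    noise = rng.standard_normal(x.shape)
    nonzero = 1.0 if t != 0 else 0.0  # no noise at the final step (t == 0)
    sample = mean + nonzero * np.exp(0.5 * post_logvar[t]) * noise
    if inpainting_mask is None:
        inpainting_mask = np.ones_like(x)
    # Where the mask is 0, keep the current x; where it is 1, take the new sample.
    return (1 - inpainting_mask) * x + inpainting_mask * sample


betas = np.linspace(1e-4, 0.02, 50)  # toy schedule (assumption)
x = np.random.default_rng(1).standard_normal((2, 4))
fake_eps = np.zeros_like(x)  # stand-in for model(x, t)
mask = np.array([[1, 1, 0, 0]] * 2, dtype=float)
out = p_sample_step(x, 10, fake_eps, betas, inpainting_mask=mask)
print(out)
```

Columns where the mask is 0 come back unchanged, which is how the inpainting pipeline preserves the known region at every step.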
""" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = torch.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if inpainting_mask is None: inpainting_mask = torch.ones_like(x, device=x.device) sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise sample = (1 - inpainting_mask)*x + inpainting_mask*sample return {'sample': sample, 'pred_xstart': out['pred_xstart']} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, inpainting_mask=None, denoised_fn=None, model_kwargs=None, device=None, progress=False, sample_fn=None, ): """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. :return: a non-differentiable batch of samples. 
""" final = None for step_idx, sample in enumerate(self.p_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, denoised_fn=denoised_fn, inpainting_mask=inpainting_mask, model_kwargs=model_kwargs, device=device, progress=progress, )): if sample_fn is not None: sample = sample_fn(step_idx, sample) final = sample return final['sample'] def p_sample_loop_progressive( self, model, shape, inpainting_mask=None, noise=None, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, denoised_fn=None, model_kwargs=None, device=None, progress=False, ): """ Generate samples from the model and yield intermediate samples from each timestep of diffusion. Arguments are the same as p_sample_loop(). Returns a generator over dicts, where each dict is the return value of p_sample(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = torch.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = torch.tensor([i] * shape[0], device=device) with torch.no_grad(): out = self.p_sample( model, img, t, clip_denoised=clip_denoised, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, denoised_fn=denoised_fn, inpainting_mask=inpainting_mask, model_kwargs=model_kwargs, ) yield out img = out['sample'] def ddim_sample( self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). 
""" out = self.p_mean_variance( model, x, t, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out['pred_xstart']) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = torch.randn_like(x) mean_pred = ( out['pred_xstart'] * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {'sample': sample, 'pred_xstart': out['pred_xstart']} def ddim_reverse_sample( self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, 'Reverse ODE only for deterministic path' out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out['pred_xstart'] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. 
reversed mean_pred = ( out['pred_xstart'] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps ) return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, denoised_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, sample_fn=None, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ final = None for step_idx, sample in enumerate(self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, )): if sample_fn is not None: sample = sample_fn(step_idx, sample) final = sample return final['sample'] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, denoised_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = torch.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] if progress: # Lazy import so that we don't depend on tqdm. 
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    dynamic_thresholding_p=dynamic_thresholding_p,
                    dynamic_thresholding_c=dynamic_thresholding_c,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out['sample']

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out['mean'], out['log_variance']
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out['mean'], log_scales=0.5 * out['log_variance']
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = torch.where((t == 0), decoder_nll, kl)
        return {'output': output, 'pred_xstart': out['pred_xstart']}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms['loss'] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )['output']
            if self.loss_type == LossType.RESCALED_KL:
                terms['loss'] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = torch.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                terms['vb'] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )['output']
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms['vb'] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms['mse'] = mean_flat((target - model_output) ** 2)
            if 'vb' in terms:
                terms['loss'] = terms['mse'] + terms['vb']
            else:
                terms['loss'] = terms['mse']
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
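
Example (a scalar sketch mirroring normal_kl() in losses.py and the `/ np.log(2.0)` conversion used by _vb_terms_bpd() and _prior_bpd(); it is not part of this class): each VLB term is a Gaussian KL divergence in nats, reported in bits by dividing by ln 2. `nats_to_bits` is a hypothetical helper name.

```python
import math


def normal_kl(mean1, logvar1, mean2, logvar2):
    # KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) in nats, on Python floats.
    return 0.5 * (
        -1.0 + logvar2 - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


def nats_to_bits(value):
    # The VLB terms are divided by ln(2) to report bits instead of nats.
    return value / math.log(2.0)


print(normal_kl(0.0, 0.0, 0.0, 0.0))                # identical Gaussians
print(nats_to_bits(normal_kl(0.0, 0.0, 1.0, 0.0)))  # unit mean shift, in bits
```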
""" device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = torch.tensor([t] * batch_size, device=device) noise = torch.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with torch.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out['output']) xstart_mse.append(mean_flat((out['pred_xstart'] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out['pred_xstart']) mse.append(mean_flat((eps - noise) ** 2)) vb = torch.stack(vb, dim=1) xstart_mse = torch.stack(xstart_mse, dim=1) mse = torch.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { 'total_bpd': total_bpd, 'prior_bpd': prior_bpd, 'vb': vb, 'xstart_mse': xstart_mse, 'mse': mse, } def _extract_into_tensor(arr, timesteps, broadcast_shape): """ Extract values from a 1-D numpy array for a batch of indices. :param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. :param broadcast_shape: a larger shape of K dimensions with the batch dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. """ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape) ================================================ FILE: deepfloyd_if/model/losses.py ================================================ # -*- coding: utf-8 -*- """ Helpers for various likelihood-based losses. These are ported from the original Ho et al. 
diffusion model codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""
import torch
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two Gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, 'at least one argument must be a Tensor'

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that these were uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
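
Example (a standalone scalar sketch, not this function): each uint8 level rescaled to [-1, 1] owns a bin of half-width 1/255, and the two extreme levels absorb the tails, so the 256 bin probabilities sum to 1. For clarity the sketch uses the exact normal CDF via `math.erf`; the function itself uses the tanh approximation approx_standard_normal_cdf() (reproduced here as `approx_cdf`) and works with clamped log probabilities. The mean 0.1 and scale 0.3 are arbitrary assumptions.

```python
import math


def approx_cdf(x):
    # tanh-based approximation, as in approx_standard_normal_cdf
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def exact_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def bin_prob(x, mean, log_scale, cdf=exact_cdf):
    # Probability mass of the 8-bit bin centred at x (half-width 1/255 in [-1, 1] units).
    inv_std = math.exp(-log_scale)
    cdf_plus = cdf(inv_std * (x - mean + 1.0 / 255.0))
    cdf_min = cdf(inv_std * (x - mean - 1.0 / 255.0))
    if x < -0.999:  # lowest bin absorbs everything below it
        return cdf_plus
    if x > 0.999:   # highest bin absorbs everything above it
        return 1.0 - cdf_min
    return cdf_plus - cdf_min


levels = [-1.0 + 2.0 * k / 255.0 for k in range(256)]  # uint8 values rescaled to [-1, 1]
total = sum(bin_prob(v, mean=0.1, log_scale=math.log(0.3)) for v in levels)
print(total)  # the bins partition the real line, so this is 1 up to rounding
print(max(abs(approx_cdf(z / 10) - exact_cdf(z / 10)) for z in range(-40, 41)))
```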
""" assert x.shape == means.shape == log_scales.shape centered_x = x - means inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1.0 / 255.0) cdf_plus = approx_standard_normal_cdf(plus_in) min_in = inv_stdv * (centered_x - 1.0 / 255.0) cdf_min = approx_standard_normal_cdf(min_in) log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = torch.where( x < -0.999, log_cdf_plus, torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), ) assert log_probs.shape == x.shape return log_probs ================================================ FILE: deepfloyd_if/model/nn.py ================================================ # -*- coding: utf-8 -*- import math import torch import torch.nn.functional as F from torch import nn from torch import Tensor def mean_flat(tensor): """ Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) def gelu(x): return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) @torch.jit.script def gelu_jit(x): """OpenAI's gelu implementation.""" return gelu(x) class GELUJit(torch.nn.Module): def forward(self, input: Tensor) -> Tensor: return gelu_jit(input) def get_activation(activation): if activation == 'silu': return torch.nn.SiLU() elif activation == 'gelu_jit': return GELUJit() elif activation == 'gelu': return torch.nn.GELU() elif activation == 'none': return torch.nn.Identity() else: raise ValueError(f'unknown activation type {activation}') class GroupNorm32(nn.GroupNorm): def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None): super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype) def forward(self, x): y = super().forward(x).to(x.dtype) return y class AttentionPooling(nn.Module): def __init__(self, num_heads, embed_dim, dtype=None): super().__init__() self.dtype = dtype 
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs*self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, class_token_length, dim_per_head)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, length+class_token_length, dim_per_head)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum(
            'bct,bcs->bts', q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum('bts,bcs->bct', weight, v)

        # (bs, length+1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
""" if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f'unsupported dimensions: {dims}') def linear(*args, **kwargs): """ Create a linear module. """ return nn.Linear(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. """ if dims == 1: return nn.AvgPool1d(*args, **kwargs) elif dims == 2: return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) raise ValueError(f'unsupported dimensions: {dims}') def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def scale_module(module, scale): """ Scale the parameters of a module and return it. """ for p in module.parameters(): p.detach().mul_(scale) return module def normalization(channels, dtype=None): """ Make a standard normalization layer. :param channels: number of input channels. :return: an nn.Module for normalization. """ return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype) def timestep_embedding(timesteps, dim, max_period=10000, dtype=None): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" if dtype is None: dtype = torch.float32 half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device, dtype=dtype) args = timesteps[:, None].type(dtype) * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def attention(q, k, v, d_k): scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) scores = F.softmax(scores, dim=-1) output = torch.matmul(scores, v) return output ================================================ FILE: deepfloyd_if/model/resample.py ================================================ # -*- coding: utf-8 -*- from abc import ABC, abstractmethod import torch import numpy as np class ScheduleSampler(ABC): """ A distribution over timesteps in the diffusion process, intended to reduce variance of the objective. By default, samplers perform unbiased importance sampling, in which the objective's mean is unchanged. However, subclasses may override sample() to change how the resampled terms are reweighted, allowing for actual changes in the objective. """ @abstractmethod def weights(self): """ Get a numpy array of weights, one per diffusion step. The weights needn't be normalized, but must be positive. """ def sample(self, batch_size, device): """ Importance-sample timesteps for a batch. :param batch_size: the number of timesteps. :param device: the torch device to save to. :return: a tuple (timesteps, weights): - timesteps: a tensor of timestep indices. - weights: a tensor of weights to scale the resulting losses. 
""" w = self.weights() p = w / np.sum(w) indices_np = np.random.choice(len(p), size=(batch_size,), p=p) indices = torch.from_numpy(indices_np).long().to(device) weights_np = 1 / (len(p) * p[indices_np]) weights = torch.from_numpy(weights_np).float().to(device) return indices, weights class UniformSampler(ScheduleSampler): def __init__(self, num_timesteps): self._weights = np.ones([num_timesteps]) def weights(self): return self._weights class StaticSampler(ABC): def sample(self, batch_size, device, static_step=100): indices_np = np.ones(batch_size, dtype=np.int) * static_step weights_np = np.ones(batch_size, dtype=np.int) indices = torch.from_numpy(indices_np).long().to(device) weights = torch.from_numpy(weights_np).float().to(device) return indices, weights ================================================ FILE: deepfloyd_if/model/respace.py ================================================ # -*- coding: utf-8 -*- import torch import numpy as np from . import gaussian_diffusion as gd def create_gaussian_diffusion( *, steps=1000, learn_sigma=False, sigma_small=False, noise_schedule='linear', use_kl=False, predict_xstart=False, rescale_timesteps=False, rescale_learned_sigmas=False, timestep_respacing='', ): betas = gd.get_named_beta_schedule(noise_schedule, steps) if use_kl: loss_type = gd.LossType.RESCALED_KL elif rescale_learned_sigmas: loss_type = gd.LossType.RESCALED_MSE else: loss_type = gd.LossType.MSE if not timestep_respacing: timestep_respacing = [steps] return SpacedDiffusion( use_timesteps=space_timesteps(steps, timestep_respacing), betas=betas, model_mean_type=( gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X ), model_var_type=( ( gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL ) if not learn_sigma else gd.ModelVarType.LEARNED_RANGE ), loss_type=loss_type, rescale_timesteps=rescale_timesteps, ) def space_timesteps(num_timesteps, section_counts): """ Create a list of timesteps to use from an 
original diffusion process, given the number of timesteps we want to take from equally-sized portions of the original process. For example, if there are 300 timesteps and the section counts are [10,15,20], then the first 100 timesteps are strided to be 10 timesteps, the second 100 are strided to be 15 timesteps, and the final 100 are strided to be 20. If the stride is a string starting with "ddim", then the fixed striding from the DDIM paper is used, and only one section is allowed. :param num_timesteps: the number of diffusion steps in the original process to divide up. :param section_counts: either a list of numbers, or a string containing comma-separated numbers, indicating the step count per section. As a special case, use "ddimN" where N is a number of steps to use the striding from the DDIM paper. :return: a set of diffusion steps from the original process to use. """ if isinstance(section_counts, str): if section_counts.startswith('ddim'): desired_count = int(section_counts[len('ddim'):]) for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) raise ValueError( f'cannot create exactly {desired_count} steps with an integer stride' ) elif section_counts == 'fast27': steps = space_timesteps(num_timesteps, '10,10,3,2,2') # Help reduce DDIM artifacts from noisiest timesteps.
steps.remove(num_timesteps - 1) steps.add(num_timesteps - 3) return steps section_counts = [int(x) for x in section_counts.split(',')] size_per = num_timesteps // len(section_counts) extra = num_timesteps % len(section_counts) start_idx = 0 all_steps = [] for i, section_count in enumerate(section_counts): size = size_per + (1 if i < extra else 0) if size < section_count: raise ValueError( f'cannot divide section of {size} steps into {section_count}' ) if section_count <= 1: frac_stride = 1 else: frac_stride = (size - 1) / (section_count - 1) cur_idx = 0.0 taken_steps = [] for _ in range(section_count): taken_steps.append(start_idx + round(cur_idx)) cur_idx += frac_stride all_steps += taken_steps start_idx += size return set(all_steps) class SpacedDiffusion(gd.GaussianDiffusion): """ A diffusion process which can skip steps in a base diffusion process. :param use_timesteps: a collection (sequence or set) of timesteps from the original diffusion process to retain. :param kwargs: the kwargs to create the base diffusion process. 
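Illustrative sketch of how the retained steps receive new betas, using only the standard library. The helper `respace_betas` is hypothetical; it mirrors the loop in `__init__`. For each kept step i, the new beta is 1 - alphas_cumprod[i] / alphas_cumprod[previous kept step]. The cumulative products of (1 - new_beta) then reproduce the original alphas_cumprod at exactly the kept steps. `timestep_map` records which original index each new step corresponds to.

```python
def respace_betas(betas, use_timesteps):
    # Hypothetical helper for illustration; not part of deepfloyd_if.
    # Cumulative products of (1 - beta) give the original alphas_cumprod.
    alphas_cumprod = []
    prod = 1.0
    for b in betas:
        prod *= 1.0 - b
        alphas_cumprod.append(prod)
    new_betas, timestep_map = [], []
    last_alpha_cumprod = 1.0
    for i, alpha_cumprod in enumerate(alphas_cumprod):
        if i in use_timesteps:
            # Choose beta so the kept steps chain to the same alphas_cumprod.
            new_betas.append(1.0 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    return new_betas, timestep_map
```

For example, `respace_betas([0.1] * 10, {0, 4, 9})` yields three betas and `timestep_map == [0, 4, 9]`.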
""" def __init__(self, use_timesteps, **kwargs): self.use_timesteps = set(use_timesteps) self.timestep_map = [] self.original_num_steps = len(kwargs['betas']) base_diffusion = gd.GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa last_alpha_cumprod = 1.0 new_betas = [] for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): if i in self.use_timesteps: new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) last_alpha_cumprod = alpha_cumprod self.timestep_map.append(i) kwargs['betas'] = np.array(new_betas) super().__init__(**kwargs) def p_mean_variance( self, model, *args, **kwargs ): # pylint: disable=signature-differs return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) def training_losses( self, model, *args, **kwargs ): # pylint: disable=signature-differs return super().training_losses(self._wrap_model(model), *args, **kwargs) def _wrap_model(self, model): if isinstance(model, _WrappedModel): return model return _WrappedModel( model, self.timestep_map, self.rescale_timesteps, self.original_num_steps ) def _scale_timesteps(self, t): # Scaling is done by the wrapped model. 
return t class _WrappedModel: def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): self.model = model self.timestep_map = timestep_map self.rescale_timesteps = rescale_timesteps self.original_num_steps = original_num_steps def __call__(self, x, ts, **kwargs): map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) new_ts = map_tensor[ts] if self.rescale_timesteps: new_ts = new_ts.float() * (1000.0 / self.original_num_steps) return self.model(x, new_ts, **kwargs) ================================================ FILE: deepfloyd_if/model/unet.py ================================================ # -*- coding: utf-8 -*- import os import math from abc import abstractmethod import torch import numpy as np import torch.nn as nn import torch.nn.functional as F from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \ AttentionPooling _FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0)) print('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION') if _FORCE_MEM_EFFICIENT_ATTN: from xformers.ops import memory_efficient_attention # noqa class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. """ @abstractmethod def forward(self, x, emb): """ Apply the module to `x` given `emb` timestep embeddings. """ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. """ def forward(self, x, emb, encoder_out=None): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, AttentionBlock): x = layer(x, encoder_out) else: x = layer(x) return x class Upsample(nn.Module): """ An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims self.dtype = dtype if use_conv: self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') else: if self.dtype == torch.bfloat16: x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16) x = F.interpolate(x, scale_factor=2, mode='nearest') if self.dtype == torch.bfloat16: x = x.type(torch.bfloat16) if self.use_conv: x = self.conv(x) return x class Downsample(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims self.dtype = dtype stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. """ def __init__( self, channels, emb_channels, dropout, activation, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, up=False, down=False, dtype=None, efficient_activation=False, scale_skip_connection=False, ): super().__init__() self.dtype = dtype self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_scale_shift_norm = use_scale_shift_norm self.efficient_activation = efficient_activation self.scale_skip_connection = scale_skip_connection self.in_layers = nn.Sequential( normalization(channels, dtype=self.dtype), get_activation(activation), conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims, dtype=self.dtype) self.x_upd = Upsample(channels, False, dims, dtype=self.dtype) elif down: self.h_upd = Downsample(channels, False, dims, dtype=self.dtype) self.x_upd = Downsample(channels, False, dims, dtype=self.dtype) else: self.h_upd = self.x_upd = nn.Identity() self.emb_layers = nn.Sequential( nn.Identity() if self.efficient_activation else get_activation(activation), linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=self.dtype ), ) self.out_layers = nn.Sequential( normalization(self.out_channels, dtype=self.dtype), get_activation(activation), nn.Dropout(p=dropout), zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)), ) if self.out_channels ==
channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype) def forward(self, x, emb): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = torch.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: h = h + emb_out h = self.out_layers(h) res = self.skip_connection(x) + h if self.scale_skip_connection: res *= 0.7071 # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf return res class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
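Illustrative sketch of the scaled dot-product attention that `QKVAttention` computes per head, using only the standard library. The helper `attend` is hypothetical. As in that module, the 1 / sqrt(ch) factor is split into 1 / ch**0.25, applied to both queries and keys before the dot product. This gives the same logits as dividing afterwards, and the module's own comment notes it is more stable with float16. The sketch omits heads, batching, and the concatenation of encoder keys and values.

```python
import math


def attend(q, k, v):
    # Hypothetical helper for illustration; not part of deepfloyd_if.
    # q: T query vectors; k, v: S key/value vectors (plain lists of floats).
    ch = len(q[0])
    scale = 1 / math.sqrt(math.sqrt(ch))  # ch ** -0.25 per side, ch ** -0.5 overall
    out = []
    for qi in q:
        logits = [sum((a * scale) * (b * scale) for a, b in zip(qi, kj)) for kj in k]
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]  # subtract the max for a stable softmax
        z = sum(exps)
        weights = [e / z for e in exps]
        # Each output row is the attention-weighted average of the value vectors.
        out.append(
            [sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))]
        )
    return out
```

A zero query attends equally to every key, so its output is the plain mean of the value vectors.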
""" def __init__( self, channels, num_heads=1, num_head_channels=-1, disable_self_attention=False, encoder_channels=None, dtype=None, ): super().__init__() self.dtype = dtype self.channels = channels self.disable_self_attention = disable_self_attention if num_head_channels == -1: self.num_heads = num_heads else: assert ( channels % num_head_channels == 0 ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' self.num_heads = channels // num_head_channels self.norm = normalization(channels, dtype=self.dtype) self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype) if self.disable_self_attention: self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype) else: self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype) self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention) if encoder_channels is not None: self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype) self.norm_encoder = normalization(encoder_channels, dtype=self.dtype) self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype)) def forward(self, x, encoder_out=None): b, c, *spatial = x.shape qkv = self.qkv(self.norm(x).view(b, c, -1)) if encoder_out is not None: # from imagen article: https://arxiv.org/pdf/2205.11487.abs encoder_out = self.norm_encoder(encoder_out) # # # encoder_out = self.encoder_kv(encoder_out) h = self.attention(qkv, encoder_out) else: h = self.attention(qkv) h = self.proj_out(h) return x + h.reshape(b, c, *spatial) class QKVAttention(nn.Module): """ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping """ def __init__(self, n_heads, disable_self_attention=False): super().__init__() self.n_heads = n_heads self.disable_self_attention = disable_self_attention def forward(self, qkv, encoder_kv=None): """ Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 
:return: an [N x (H * C) x T] tensor after attention. """ bs, width, length = qkv.shape if self.disable_self_attention: ch = width // (1 * self.n_heads) q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1) else: assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) if encoder_kv is not None: assert encoder_kv.shape[1] == self.n_heads * ch * 2 if self.disable_self_attention: k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) else: ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) k = torch.cat([ek, k], dim=-1) v = torch.cat([ev, v], dim=-1) scale = 1 / math.sqrt(math.sqrt(ch)) if _FORCE_MEM_EFFICIENT_ATTN: q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) a = memory_efficient_attention(q, k, v) a = a.permute(0, 2, 1) else: weight = torch.einsum( 'bct,bcs->bts', q * scale, k * scale ) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum('bts,bcs->bct', weight, v) return a.reshape(bs, -1, length) class UNetModel(nn.Module): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param num_heads: the number of attention heads in each attention layer. :param num_head_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. """ def __init__( self, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, activation, encoder_dim, att_pool_heads, encoder_channels, image_size, disable_self_attentions=None, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, precision='32', num_heads=1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, efficient_activation=False, scale_skip_connection=False, ): super().__init__() if num_heads_upsample == -1: num_heads_upsample = num_heads self.encoder_channels = encoder_channels self.encoder_dim = encoder_dim self.efficient_activation = efficient_activation self.scale_skip_connection = scale_skip_connection self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.dropout = dropout # adapt attention resolutions if isinstance(attention_resolutions, str): self.attention_resolutions = [] for res in attention_resolutions.split(','): self.attention_resolutions.append(image_size // int(res)) else: self.attention_resolutions = attention_resolutions self.attention_resolutions = tuple(self.attention_resolutions) # # adapt disable self attention resolutions if not disable_self_attentions: self.disable_self_attentions = [] elif disable_self_attentions is True: self.disable_self_attentions = attention_resolutions elif isinstance(disable_self_attentions, str):
self.disable_self_attentions = [] for res in disable_self_attentions.split(','): self.disable_self_attentions.append(image_size // int(res)) else: self.disable_self_attentions = disable_self_attentions self.disable_self_attentions = tuple(self.disable_self_attentions) # # adapt channel mult if isinstance(channel_mult, str): self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(',')) else: self.channel_mult = tuple(channel_mult) # self.conv_resample = conv_resample self.num_classes = num_classes self.dtype = torch.float32 self.precision = str(precision) self.use_fp16 = precision == '16' if self.precision == '16': self.dtype = torch.float16 elif self.precision == 'bf16': self.dtype = torch.bfloat16 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.time_embed_dim = model_channels * max(self.channel_mult) self.time_embed = nn.Sequential( linear(model_channels, self.time_embed_dim, dtype=self.dtype), get_activation(activation), linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), ) if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, self.time_embed_dim) ch = input_ch = int(self.channel_mult[0] * model_channels) self.input_blocks = nn.ModuleList( [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))] ) self._feature_size = ch input_block_chans = [ch] ds = 1 if isinstance(num_res_blocks, int): num_res_blocks = [num_res_blocks]*len(self.channel_mult) self.num_res_blocks = num_res_blocks for level, mult in enumerate(self.channel_mult): for _ in range(num_res_blocks[level]): layers = [ ResBlock( ch, self.time_embed_dim, dropout, out_channels=int(mult * model_channels), dims=dims, use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, activation=activation, efficient_activation=self.efficient_activation, scale_skip_connection=self.scale_skip_connection, ) ] ch = int(mult * model_channels) if ds in 
self.attention_resolutions: layers.append( AttentionBlock( ch, num_heads=num_heads, num_head_channels=num_head_channels, encoder_channels=encoder_channels, dtype=self.dtype, disable_self_attention=ds in self.disable_self_attentions, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(self.channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, self.time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_scale_shift_norm=use_scale_shift_norm, down=True, dtype=self.dtype, activation=activation, efficient_activation=self.efficient_activation, scale_skip_connection=self.scale_skip_connection, ) if resblock_updown else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch self.middle_block = TimestepEmbedSequential( ResBlock( ch, self.time_embed_dim, dropout, dims=dims, use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, activation=activation, efficient_activation=self.efficient_activation, scale_skip_connection=self.scale_skip_connection, ), AttentionBlock( ch, num_heads=num_heads, num_head_channels=num_head_channels, encoder_channels=encoder_channels, dtype=self.dtype, disable_self_attention=ds in self.disable_self_attentions, ), ResBlock( ch, self.time_embed_dim, dropout, dims=dims, use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, activation=activation, efficient_activation=self.efficient_activation, scale_skip_connection=self.scale_skip_connection, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(self.channel_mult))[::-1]: for i in range(num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, self.time_embed_dim, dropout, out_channels=int(model_channels * mult), dims=dims, use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, 
activation=activation, efficient_activation=self.efficient_activation, scale_skip_connection=self.scale_skip_connection, ) ] ch = int(model_channels * mult) if ds in self.attention_resolutions: layers.append( AttentionBlock( ch, num_heads=num_heads_upsample, num_head_channels=num_head_channels, encoder_channels=encoder_channels, dtype=self.dtype, disable_self_attention=ds in self.disable_self_attentions, ) ) if level and i == num_res_blocks[level]: out_ch = ch layers.append( ResBlock( ch, self.time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_scale_shift_norm=use_scale_shift_norm, up=True, dtype=self.dtype, activation=activation, efficient_activation=self.efficient_activation, scale_skip_connection=self.scale_skip_connection, ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( normalization(ch, dtype=self.dtype), get_activation(activation), zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)), ) self.activation_layer = get_activation(activation) if self.efficient_activation else nn.Identity() self.encoder_pooling = nn.Sequential( nn.LayerNorm(encoder_dim, dtype=self.dtype), AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype), nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype), nn.LayerNorm(self.time_embed_dim, dtype=self.dtype) ) if encoder_dim != encoder_channels: self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype) else: self.encoder_proj = nn.Identity() self.cache = None def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs): hs = [] emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, dtype=self.dtype)) if use_cache and self.cache is not None: encoder_out, encoder_pool = self.cache else: text_emb = text_emb.type(self.dtype) encoder_out = 
self.encoder_proj(text_emb) encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL if timestep_text_emb is None: timestep_text_emb = text_emb encoder_pool = self.encoder_pooling(timestep_text_emb) if use_cache: self.cache = (encoder_out, encoder_pool) emb = emb + encoder_pool.to(emb) if aug_emb is not None: emb = emb + aug_emb.to(emb) emb = self.activation_layer(emb) h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb, encoder_out) hs.append(h) h = self.middle_block(h, emb, encoder_out) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb, encoder_out) h = h.type(self.dtype) h = self.out(h) return h class SuperResUNetModel(UNetModel): """ A text2im model that performs super-resolution. Expects an extra kwarg `low_res` to condition on a low-resolution image. """ def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs): self.low_res_diffusion = low_res_diffusion self.interpolate_mode = interpolate_mode super().__init__(*args, **kwargs) self.aug_proj = nn.Sequential( linear(self.model_channels, self.time_embed_dim, dtype=self.dtype), get_activation(kwargs['activation']), linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), ) def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): bs, _, new_height, new_width = x.shape align_corners = True if self.interpolate_mode == 'nearest': align_corners = None upsampled = F.interpolate( low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners ) if aug_level is None: aug_steps = (np.random.random(bs)*1000).astype(np.int64) # uniform [0, 1) aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long) else: aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long) upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps) x = torch.cat([x, upsampled], dim=1) aug_emb = self.aug_proj( timestep_embedding(aug_steps, self.model_channels, 
dtype=self.dtype) ) return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs) ================================================ FILE: deepfloyd_if/modules/__init__.py ================================================ # -*- coding: utf-8 -*- from .stage_I import IFStageI from .stage_II import IFStageII from .stage_III import IFStageIII from .stage_III_sd_x4 import StableStageIII from .t5 import T5Embedder from .base import IFBaseModule __all__ = ['IFBaseModule', 'IFStageI', 'IFStageII', 'IFStageIII', 'StableStageIII', 'T5Embedder'] ================================================ FILE: deepfloyd_if/modules/base.py ================================================ # -*- coding: utf-8 -*- import os import random import platform from datetime import datetime import torch import torchvision import numpy as np import matplotlib.pyplot as plt import torchvision.transforms as T from PIL import Image from omegaconf import OmegaConf from huggingface_hub import hf_hub_download from accelerate.utils import set_module_tensor_to_device from .. import utils from ..model.respace import create_gaussian_diffusion from .utils import load_model_weights, predict_proba, clip_process_generations class IFBaseModule: stage = '-' available_models = [] cpu_zero_emb = np.load(os.path.join(utils.RESOURCES_ROOT, 'zero_t5-v1_1-xxl_vector.npy')) cpu_zero_emb = torch.from_numpy(cpu_zero_emb) respacing_modes = { 'fast27': '10,10,3,2,2', 'smart27': '7,4,2,1,2,4,7', 'smart50': '10,6,4,3,2,2,3,4,6,10', 'smart100': '1,1,1,1,2,2,2,2,2,2,3,3,4,4,5,5,6,7,7,8,9,10,13', 'smart185': '1,1,2,2,2,3,3,3,4,5,6,7,8,9,10,11,12,13,14,15,16,18,20', 'super27': '1,1,1,1,1,1,1,2,5,13', # for III super-res 'super40': '2,2,2,2,2,2,3,4,6,15', # for III super-res 'super100': '4,4,6,6,8,8,10,10,14,30', # for III super-res } wm_pil_img = Image.open(os.path.join(utils.RESOURCES_ROOT, 'wm.png')) try: import clip # noqa except ModuleNotFoundError: print('Warning! 
You should install CLIP: "pip install git+https://github.com/openai/CLIP.git --no-deps"') raise clip_model, clip_preprocess = clip.load('ViT-L/14', device='cpu') clip_model.eval() cpu_w_weights, cpu_w_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'w_head_v1.npz')) cpu_p_weights, cpu_p_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'p_head_v1.npz')) w_threshold, p_threshold = 0.5, 0.5 def __init__(self, dir_or_name, device, pil_img_size=256, cache_dir=None, hf_token=None): self.hf_token = hf_token self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') self.dir_or_name = dir_or_name self.conf = self.load_conf(dir_or_name) if not self.use_diffusers else None self.device = torch.device(device) self.zero_emb = self.cpu_zero_emb.clone().to(self.device) self.pil_img_size = pil_img_size @property def use_diffusers(self): return False def embeddings_to_image( self, t5_embs, low_res=None, *, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', sample_timestep_respacing='smart185', dynamic_thresholding_c=1.5, guidance_scale=7.0, aug_level=0.25, positive_mixer=0.15, blur_sigma=None, img_size=None, img_scale=4.0, aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, support_noise=None, support_noise_less_qsample_steps=0, inpainting_mask=None, **kwargs, ): self._clear_cache() image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale) diffusion = self.get_diffusion(sample_timestep_respacing) bs_scale = 2 if positive_t5_embs is None else 3 def model_fn(x_t, ts, **kwargs): half = x_t[: len(x_t) // bs_scale] combined = torch.cat([half]*bs_scale, dim=0) model_out = self.model(combined, ts, **kwargs) eps, rest = model_out[:, :3], model_out[:, 3:] if bs_scale == 3: cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0) half_eps = uncond_eps + guidance_scale * ( cond_eps * (1 - positive_mixer) + 
pos_cond_eps * positive_mixer - uncond_eps) pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps) eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0) else: cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0) half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) return torch.cat([eps, rest], dim=1) seed = self.seed_everything(seed) text_emb = t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1) batch_size = text_emb.shape[0] if positive_t5_embs is not None: positive_t5_embs = positive_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1) if negative_t5_embs is not None: negative_t5_embs = negative_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1) timestep_text_emb = None if style_t5_embs is not None: list_timestep_text_emb = [ style_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1), ] if positive_t5_embs is not None: list_timestep_text_emb.append(positive_t5_embs) if negative_t5_embs is not None: list_timestep_text_emb.append(negative_t5_embs) else: list_timestep_text_emb.append( self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype)) timestep_text_emb = torch.cat(list_timestep_text_emb, dim=0).to(self.device, dtype=self.model.dtype) metadata = { 'seed': seed, 'guidance_scale': guidance_scale, 'dynamic_thresholding_p': dynamic_thresholding_p, 'dynamic_thresholding_c': dynamic_thresholding_c, 'batch_size': batch_size, 'device_name': self.device_name(), 'img_size': [image_w, image_h], 'sample_loop': sample_loop, 'sample_timestep_respacing': sample_timestep_respacing, 'stage': self.stage, } list_text_emb = [text_emb] if positive_t5_embs is not None: list_text_emb.append(positive_t5_embs.to(self.device)) if negative_t5_embs is not None: list_text_emb.append(negative_t5_embs.to(self.device)) else: list_text_emb.append(
self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype)) model_kwargs = dict( text_emb=torch.cat(list_text_emb, dim=0).to(self.device, dtype=self.model.dtype), timestep_text_emb=timestep_text_emb, use_cache=True, ) if low_res is not None: if blur_sigma is not None: low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res) model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device) model_kwargs['aug_level'] = aug_level if support_noise is None: noise = torch.randn( (batch_size * bs_scale, 3, image_h, image_w), device=self.device, dtype=self.model.dtype) else: assert support_noise_less_qsample_steps < len(diffusion.timestep_map) - 1 assert support_noise.shape == (1, 3, image_h, image_w) q_sample_steps = torch.tensor([int(len(diffusion.timestep_map) - 1 - support_noise_less_qsample_steps)]) support_noise = support_noise.cpu() noise = support_noise.clone() noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...] 
= diffusion.q_sample( support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...], q_sample_steps, ) noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype) if inpainting_mask is not None: inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long) if sample_loop == 'ddpm': with torch.no_grad(): sample = diffusion.p_sample_loop( model_fn, (batch_size * bs_scale, 3, image_h, image_w), noise=noise, clip_denoised=True, model_kwargs=model_kwargs, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, inpainting_mask=inpainting_mask, device=self.device, progress=progress, sample_fn=sample_fn, )[:batch_size] elif sample_loop == 'ddim': with torch.no_grad(): sample = diffusion.ddim_sample_loop( model_fn, (batch_size * bs_scale, 3, image_h, image_w), noise=noise, clip_denoised=True, model_kwargs=model_kwargs, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, device=self.device, progress=progress, sample_fn=sample_fn, )[:batch_size] else: raise ValueError(f'Sample loop "{sample_loop}" is not supported; use "ddpm" or "ddim"') sample = self.__validate_generations(sample) self._clear_cache() return sample, metadata def load_conf(self, dir_or_name, filename='config.yml'): path = self._get_path_or_download_file_from_hf(dir_or_name, filename) conf = OmegaConf.load(path) return conf def load_checkpoint(self, model, dir_or_name, filename='pytorch_model.bin'): path = self._get_path_or_download_file_from_hf(dir_or_name, filename) if os.path.exists(path): checkpoint = torch.load(path, map_location='cpu') param_device = 'cpu' for param_name, param in checkpoint.items(): set_module_tensor_to_device(model, param_name, param_device, value=param) else: print(f'Warning!
In directory "{dir_or_name}" file "{filename}" was not found.') return model def _get_path_or_download_file_from_hf(self, dir_or_name, filename): if dir_or_name in self.available_models: cache_dir = os.path.join(self.cache_dir, dir_or_name) hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir, force_filename=filename, token=self.hf_token) return os.path.join(cache_dir, filename) else: return os.path.join(dir_or_name, filename) def get_diffusion(self, timestep_respacing): timestep_respacing = self.respacing_modes.get(timestep_respacing, timestep_respacing) diffusion = create_gaussian_diffusion( steps=1000, learn_sigma=True, sigma_small=False, noise_schedule='cosine', use_kl=False, predict_xstart=False, rescale_timesteps=True, rescale_learned_sigmas=True, timestep_respacing=timestep_respacing, ) return diffusion @staticmethod def seed_everything(seed=None): if seed is None: seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1)) random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False return seed def device_name(self): if self.device.type == 'cpu': return 'cpu_' + str(platform.processor()) if self.device.type == 'cuda': return torch.cuda.get_device_name(self.device) return '-' def to_images(self, generations, disable_watermark=False): bs, c, h, w = generations.shape coef = min(h / self.pil_img_size, w / self.pil_img_size) img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) S1, S2 = 1024 ** 2, img_w * img_h K = (S2 / S1) ** 0.5 wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) wm_img = self.wm_pil_img.resize( (wm_size, wm_size), getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None) pil_images = [] for image in ((generations + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu(): pil_img =
torchvision.transforms.functional.to_pil_image(image).convert('RGB') pil_img = pil_img.resize((img_w, img_h), getattr(Image, 'Resampling', Image).NEAREST) if not disable_watermark: pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) pil_images.append(pil_img) return pil_images def show(self, pil_images, nrow=None, size=10): if nrow is None: nrow = round(len(pil_images)**0.5) imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) if not isinstance(imgs, list): imgs = [imgs.cpu()] fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size)) for i, img in enumerate(imgs): img = img.detach() img = torchvision.transforms.functional.to_pil_image(img) axs[0, i].imshow(np.asarray(img)) axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) fig.show() plt.show() def _clear_cache(self): self.model.cache = None def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale): if low_res is not None: bs, c, h, w = low_res.shape image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale)//32)*32 else: scale_w, scale_h = aspect_ratio.split(':') scale_w, scale_h = int(scale_w), int(scale_h) coef = scale_w / scale_h image_h, image_w = img_size, img_size if coef >= 1: image_w = int(round(img_size/8 * coef) * 8) else: image_h = int(round(img_size/8 / coef) * 8) assert image_h % 8 == 0 assert image_w % 8 == 0 return image_w, image_h def __validate_generations(self, generations): with torch.no_grad(): imgs = clip_process_generations(generations) image_features = self.clip_model.encode_image(imgs.to('cpu')) image_features = image_features.detach().cpu().numpy().astype(np.float16) p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases) w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases) query = p_pred > self.p_threshold if query.sum() > 0: generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])
query = w_pred > self.w_threshold if query.sum() > 0: generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query]) return generations ================================================ FILE: deepfloyd_if/modules/stage_I.py ================================================ # -*- coding: utf-8 -*- import accelerate from .base import IFBaseModule from ..model import UNetModel class IFStageI(IFBaseModule): stage = 'I' available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0'] def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs): """ :param dir_or_name: model name from available_models, or a local directory containing config.yml and pytorch_model.bin :param device: torch device to run the model on, e.g. 'cuda:0' :param model_kwargs: optional overrides for the UNet parameters read from config.yml :param pil_img_size: minimum side length used when converting generations to PIL images :param cache_dir: download cache directory (defaults to ~/.cache/IF_) :param hf_token: Hugging Face access token used for downloads """ super().__init__(*args, pil_img_size=pil_img_size, **kwargs) model_params = dict(self.conf.params) model_params.update(model_kwargs or {}) with accelerate.init_empty_weights(): self.model = UNetModel(**model_params) self.model = self.load_checkpoint(self.model, self.dir_or_name) self.model.eval().to(self.device) def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25, sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0, aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs): return super().embeddings_to_image( t5_embs=t5_embs, style_t5_embs=style_t5_embs, positive_t5_embs=positive_t5_embs, negative_t5_embs=negative_t5_embs, batch_repeat=batch_repeat, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, sample_loop=sample_loop, sample_timestep_respacing=sample_timestep_respacing, guidance_scale=guidance_scale, img_size=64, aspect_ratio=aspect_ratio, progress=progress, seed=seed, sample_fn=sample_fn, positive_mixer=positive_mixer, **kwargs ) ================================================ FILE: deepfloyd_if/modules/stage_II.py ================================================ # -*- coding: utf-8 -*-
import accelerate from .base import IFBaseModule from ..model import SuperResUNetModel class IFStageII(IFBaseModule): stage = 'II' available_models = ['IF-II-M-v1.0', 'IF-II-L-v1.0'] def __init__(self, *args, model_kwargs=None, pil_img_size=256, **kwargs): super().__init__(*args, pil_img_size=pil_img_size, **kwargs) model_params = dict(self.conf.params) model_params.update(model_kwargs or {}) with accelerate.init_empty_weights(): self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params) self.model = self.load_checkpoint(self.model, self.dir_or_name) self.model.eval().to(self.device) def embeddings_to_image( self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm', sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5, progress=True, seed=None, sample_fn=None, **kwargs): return super().embeddings_to_image( t5_embs=t5_embs, low_res=low_res, style_t5_embs=style_t5_embs, positive_t5_embs=positive_t5_embs, negative_t5_embs=negative_t5_embs, batch_repeat=batch_repeat, aug_level=aug_level, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, sample_loop=sample_loop, sample_timestep_respacing=sample_timestep_respacing, guidance_scale=guidance_scale, positive_mixer=positive_mixer, img_size=256, img_scale=img_scale, progress=progress, seed=seed, sample_fn=sample_fn, **kwargs ) ================================================ FILE: deepfloyd_if/modules/stage_III.py ================================================ # -*- coding: utf-8 -*- import accelerate from .base import IFBaseModule from ..model import SuperResUNetModel class IFStageIII(IFBaseModule): stage = 'III' available_models = ['IF-III-L-v1.0'] def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
model_params = dict(self.conf.params) model_params.update(model_kwargs or {}) with accelerate.init_empty_weights(): self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params) self.model = self.load_checkpoint(self.model, self.dir_or_name) self.model.eval().to(self.device) def embeddings_to_image( self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5, sample_loop='ddpm', sample_timestep_respacing='super40', guidance_scale=4.0, img_scale=4.0, progress=True, seed=None, sample_fn=None, **kwargs): return super().embeddings_to_image( t5_embs=t5_embs, low_res=low_res, style_t5_embs=style_t5_embs, positive_t5_embs=positive_t5_embs, negative_t5_embs=negative_t5_embs, batch_repeat=batch_repeat, aug_level=aug_level, blur_sigma=blur_sigma, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, sample_loop=sample_loop, sample_timestep_respacing=sample_timestep_respacing, guidance_scale=guidance_scale, positive_mixer=positive_mixer, img_size=1024, img_scale=img_scale, progress=progress, seed=seed, sample_fn=sample_fn, **kwargs ) ================================================ FILE: deepfloyd_if/modules/stage_III_sd_x4.py ================================================ # -*- coding: utf-8 -*- import diffusers from diffusers import DiffusionPipeline, DDPMScheduler import torch import os from .base import IFBaseModule import packaging.version as pv class StableStageIII(IFBaseModule): available_models = ['stable-diffusion-x4-upscaler'] def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): super().__init__(*args, pil_img_size=pil_img_size, **kwargs) if pv.parse(diffusers.__version__) <= pv.parse('0.15.1'): raise ValueError( 'Make sure to have `diffusers >= 0.16.0` installed.' 
' Please run `pip install diffusers --upgrade`' ) model_id = f'stabilityai/{self.dir_or_name}' if self.dir_or_name in self.available_models else self.dir_or_name model_kwargs = model_kwargs or {} precision = str(model_kwargs.get('precision', '16')) if precision == '16': torch_dtype = torch.float16 elif precision == 'bf16': torch_dtype = torch.bfloat16 else: torch_dtype = torch.float32 self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token) self.model.to(self.device) if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')): self.model.enable_xformers_memory_efficient_attention() @property def use_diffusers(self): if self.dir_or_name == self.available_models[-1]: return True elif os.path.isdir(self.dir_or_name) and os.path.isfile(os.path.join(self.dir_or_name, 'model_index.json')): return True return False def embeddings_to_image( self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5, sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0, progress=True, seed=None, sample_fn=None, **kwargs): prompt = kwargs.pop('prompt') noise_level = kwargs.pop('noise_level', 20) if sample_loop == 'ddpm': self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config) else: raise ValueError(f"For now only the 'ddpm' sample loop type is supported, but you passed {sample_loop}") num_inference_steps = int(sample_timestep_respacing) self.model.set_progress_bar_config(disable=not progress) generator = torch.manual_seed(seed) prompt = list(prompt) * batch_repeat low_res = low_res.repeat(batch_repeat, 1, 1, 1) metadata = { 'image': low_res, 'prompt': prompt, 'noise_level': noise_level, 'generator': generator, 'guidance_scale': guidance_scale, 'num_inference_steps': num_inference_steps, 'output_type': 'pt', } images = self.model(**metadata).images sample =
self._IFBaseModule__validate_generations(images) return sample, metadata ================================================ FILE: deepfloyd_if/modules/t5.py ================================================ # -*- coding: utf-8 -*- import os import re import html import urllib.parse as ul import ftfy import torch from bs4 import BeautifulSoup from transformers import T5EncoderModel, AutoTokenizer from huggingface_hub import hf_hub_download class T5Embedder: available_models = ['t5-v1_1-xxl'] bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None): self.device = torch.device(device) self.torch_dtype = torch_dtype or torch.bfloat16 if t5_model_kwargs is None: t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} if use_offload_folder is not None: t5_model_kwargs['offload_folder'] = use_offload_folder t5_model_kwargs['device_map'] = { 'shared': self.device, 'encoder.embed_tokens': self.device, 'encoder.block.0': self.device, 'encoder.block.1': self.device, 'encoder.block.2': self.device, 'encoder.block.3': self.device, 'encoder.block.4': self.device, 'encoder.block.5': self.device, 'encoder.block.6': self.device, 'encoder.block.7': self.device, 'encoder.block.8': self.device, 'encoder.block.9': self.device, 'encoder.block.10': self.device, 'encoder.block.11': self.device, 'encoder.block.12': 'disk', 'encoder.block.13': 'disk', 'encoder.block.14': 'disk', 'encoder.block.15': 'disk', 'encoder.block.16': 'disk', 'encoder.block.17': 'disk', 'encoder.block.18': 'disk', 'encoder.block.19': 'disk', 'encoder.block.20': 'disk', 'encoder.block.21': 'disk', 'encoder.block.22': 'disk', 'encoder.block.23': 'disk', 'encoder.final_layer_norm': 'disk', 'encoder.dropout': 'disk', } else: t5_model_kwargs['device_map'] = 
{'shared': self.device, 'encoder': self.device} self.use_text_preprocessing = use_text_preprocessing self.hf_token = hf_token self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') self.dir_or_name = dir_or_name tokenizer_path, path = dir_or_name, dir_or_name if dir_or_name in self.available_models: cache_dir = os.path.join(self.cache_dir, dir_or_name) for filename in [ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin' ]: hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir, force_filename=filename, token=self.hf_token) tokenizer_path, path = cache_dir, cache_dir else: cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') for filename in [ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', ]: hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, force_filename=filename, token=self.hf_token) tokenizer_path = cache_dir self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() def get_text_embeddings(self, texts): texts = [self.text_preprocessing(text) for text in texts] text_tokens_and_mask = self.tokenizer( texts, max_length=77, padding='max_length', truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors='pt' ) with torch.no_grad(): text_encoder_embs = self.model( input_ids=text_tokens_and_mask['input_ids'].to(self.device), attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), )['last_hidden_state'].detach() return text_encoder_embs def text_preprocessing(self, text): if self.use_text_preprocessing: # The exact text cleaning as was in the
training stage: text = self.clean_caption(text) text = self.clean_caption(text) return text else: return text.lower().strip() @staticmethod def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def clean_caption(self, caption): caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() caption = re.sub('<person>', 'person', caption) # urls: caption = re.sub( r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa '', caption) # regex for urls caption = re.sub( r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa '', caption) # regex for urls # html: caption = BeautifulSoup(caption, features='html.parser').text # @<nickname> caption = re.sub(r'@[\w\d]+\b', '', caption) # 31C0—31EF CJK Strokes # 31F0—31FF Katakana Phonetic Extensions # 3200—32FF Enclosed CJK Letters and Months # 3300—33FF CJK Compatibility # 3400—4DBF CJK Unified Ideographs Extension A # 4DC0—4DFF Yijing Hexagram Symbols # 4E00—9FFF CJK Unified Ideographs caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) caption = re.sub(r'[\u3200-\u32ff]+', '', caption) caption = re.sub(r'[\u3300-\u33ff]+', '', caption) caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) ####################################################### # all types of dash --> "-" caption = re.sub( r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa '-', caption) # normalize all quotes to a single standard caption = re.sub(r'[`´«»“”¨]', '"', caption) caption = re.sub(r'[‘’]', "'", caption) # &quot; caption = re.sub(r'&quot;?', '', caption) # &amp caption = re.sub(r'&amp', '', caption) # ip addresses: caption =
re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) # article ids: caption = re.sub(r'\d:\d\d\s+$', '', caption) # \n caption = re.sub(r'\\n', ' ', caption) # "#123" caption = re.sub(r'#\d{1,3}\b', '', caption) # "#12345.." caption = re.sub(r'#\d{5,}\b', '', caption) # "123456.." caption = re.sub(r'\b\d{6,}\b', '', caption) # filenames: caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) # caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " # this-is-my-cute-cat / this_is_my_cute_cat regex2 = re.compile(r'(?:\-|\_)') if len(re.findall(regex2, caption)) > 3: caption = re.sub(regex2, ' ', caption) caption = self.basic_clean(caption) caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) caption = re.sub(r'\bpage\s+\d+\b', '', caption) caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... 
caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) caption = re.sub(r'\b\s+\:\s+', r': ', caption) caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) caption = re.sub(r'\s+', ' ', caption) caption.strip() caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) caption = re.sub(r'^\.\S+$', '', caption) return caption.strip() ================================================ FILE: deepfloyd_if/modules/utils.py ================================================ # -*- coding: utf-8 -*- import numpy as np import torchvision.transforms as T def predict_proba(X, weights, biases): logits = X @ weights.T + biases proba = np.where(logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))) return proba.T def load_model_weights(path): model_weights = np.load(path) return model_weights['weights'], model_weights['biases'] def clip_process_generations(generations): min_size = min(generations.shape[-2:]) return T.Compose([ T.CenterCrop(min_size), T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True), T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ])(generations) ================================================ FILE: deepfloyd_if/pipelines/__init__.py ================================================ # -*- coding: utf-8 -*- from .dream import dream from .style_transfer import style_transfer from .super_resolution import super_resolution from .inpainting import inpainting __all__ = ['dream', 'style_transfer', 'super_resolution', 'inpainting'] ================================================ FILE: deepfloyd_if/pipelines/dream.py ================================================ # -*- coding: utf-8 -*- from datetime import datetime import torch def dream( t5, if_I, if_II=None, if_III=None, *, prompt, style_prompt=None, negative_prompt=None, seed=None, aspect_ratio='1:1', if_I_kwargs=None, 
if_II_kwargs=None, if_III_kwargs=None, progress=True, return_tensors=False, disable_watermark=False, ): """ Generate images from a text description. :param optional dict if_I_kwargs: "dynamic_thresholding_p": 0.95, [0.5, 1.0] controls color saturation at high guidance (CFG) values "dynamic_thresholding_c": 1.5, [1.0, 15.0] clips the limiter to avoid greyish images at high limiter values "guidance_scale": 7.0, [1.0, 20.0] controls how strongly the image follows the prompt "positive_mixer": 0.25, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum "sample_timestep_respacing": "150", see the available modes in IFBaseModule.respacing_modes or pass a custom value :param optional dict if_II_kwargs: "dynamic_thresholding_p": 0.95, [0.5, 1.0] controls color saturation at high guidance (CFG) values "dynamic_thresholding_c": 1.0, [1.0, 15.0] clips the limiter to avoid greyish images at high limiter values "guidance_scale": 4.0, [1.0, 20.0] controls the amount of texture and detail in the final image "aug_level": 0.25, [0.0, 1.0] adds extra augmentation to generate more realistic images "positive_mixer": 0.5, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum "sample_timestep_respacing": "smart50", see the available modes in IFBaseModule.respacing_modes or pass a custom value :param deepfloyd_if.modules.IFStageI if_I: obj :param deepfloyd_if.modules.IFStageII if_II: obj :param deepfloyd_if.modules.IFStageIII if_III: obj :param deepfloyd_if.modules.T5Embedder t5: obj :param int seed: random seed; if None, a time-based value is used :param str aspect_ratio: width:height ratio of the generated image, e.g. '1:1' or '16:9' :param str prompt: text description :param str style_prompt: text description of the style :param str negative_prompt: text description for the negative prompt; its embedding replaces the unconditional embedding :param bool progress: show sampling progress bars :param bool return_tensors: also return the raw generation tensors of every stage :param bool disable_watermark: do not paste the watermark onto the output images :return: dict mapping stage names ('I', 'II', 'III') to lists of PIL images; if return_tensors is True, a tuple (result, (stageI_generations, stageII_generations, stageIII_generations)) """ if seed is None: seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1)) if_I.seed_everything(seed) if isinstance(prompt, str): prompt = [prompt] t5_embs =
t5.get_text_embeddings(prompt) if_I_kwargs = if_I_kwargs or {} if_I_kwargs['seed'] = seed if_I_kwargs['t5_embs'] = t5_embs if_I_kwargs['aspect_ratio'] = aspect_ratio if_I_kwargs['progress'] = progress if style_prompt is not None: if isinstance(style_prompt, str): style_prompt = [style_prompt] style_t5_embs = t5.get_text_embeddings(style_prompt) if_I_kwargs['style_t5_embs'] = style_t5_embs if_I_kwargs['positive_t5_embs'] = style_t5_embs if negative_prompt is not None: if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] negative_t5_embs = t5.get_text_embeddings(negative_prompt) if_I_kwargs['negative_t5_embs'] = negative_t5_embs stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs) pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark) result = {'I': pil_images_I} if if_II is not None: if_II_kwargs = if_II_kwargs or {} if_II_kwargs['low_res'] = stageI_generations if_II_kwargs['seed'] = seed if_II_kwargs['t5_embs'] = t5_embs if_II_kwargs['progress'] = progress if_II_kwargs['style_t5_embs'] = if_I_kwargs.get('style_t5_embs') if_II_kwargs['positive_t5_embs'] = if_I_kwargs.get('positive_t5_embs') stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs) pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark) result['II'] = pil_images_II else: stageII_generations = None if if_II is not None and if_III is not None: if_III_kwargs = if_III_kwargs or {} stageIII_generations = [] for idx in range(len(stageII_generations)): if if_III.use_diffusers: if_III_kwargs['prompt'] = prompt[idx: idx+1] if_III_kwargs['low_res'] = stageII_generations[idx:idx+1] if_III_kwargs['seed'] = seed if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1] if_III_kwargs['progress'] = progress style_t5_embs = if_I_kwargs.get('style_t5_embs') if style_t5_embs is not None: style_t5_embs = style_t5_embs[idx:idx+1] positive_t5_embs = if_I_kwargs.get('positive_t5_embs') if positive_t5_embs is not None: 
positive_t5_embs = positive_t5_embs[idx:idx+1] if_III_kwargs['style_t5_embs'] = style_t5_embs if_III_kwargs['positive_t5_embs'] = positive_t5_embs _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs) stageIII_generations.append(_stageIII_generations) stageIII_generations = torch.cat(stageIII_generations, 0) pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark) result['III'] = pil_images_III else: stageIII_generations = None if return_tensors: return result, (stageI_generations, stageII_generations, stageIII_generations) else: return result ================================================ FILE: deepfloyd_if/pipelines/inpainting.py ================================================ # -*- coding: utf-8 -*- from datetime import datetime import PIL import torch from .utils import _prepare_pil_image def inpainting( t5, if_I, if_II=None, if_III=None, *, support_pil_img, prompt, inpainting_mask, negative_prompt=None, seed=None, if_I_kwargs=None, if_II_kwargs=None, if_III_kwargs=None, progress=True, return_tensors=False, disable_watermark=False, ): from skimage.transform import resize # noqa from skimage import img_as_bool # noqa assert isinstance(support_pil_img, PIL.Image.Image) if seed is None: seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1)) if isinstance(prompt, str): prompt = [prompt] t5_embs = t5.get_text_embeddings(prompt) if negative_prompt is not None: if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] negative_t5_embs = t5.get_text_embeddings(negative_prompt) else: negative_t5_embs = None low_res = _prepare_pil_image(support_pil_img, 64) mid_res = _prepare_pil_image(support_pil_img, 256) high_res = _prepare_pil_image(support_pil_img, 1024) result = {} _, _, image_h, image_w = low_res.shape if_I_kwargs = if_I_kwargs or {} if_I_kwargs['seed'] = seed if_I_kwargs['progress'] = progress if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}' if_I_kwargs['t5_embs'] = t5_embs if_I_kwargs['negative_t5_embs'] =
negative_t5_embs if_I_kwargs['support_noise'] = low_res inpainting_mask_I = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w))) inpainting_mask_I = torch.from_numpy(inpainting_mask_I).unsqueeze(0).to(if_I.device) if_I_kwargs['inpainting_mask'] = inpainting_mask_I stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs) pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark) result['I'] = pil_images_I if if_II is not None: _, _, image_h, image_w = mid_res.shape if_II_kwargs = if_II_kwargs or {} if_II_kwargs['low_res'] = stageI_generations if_II_kwargs['seed'] = seed if_II_kwargs['t5_embs'] = t5_embs if_II_kwargs['negative_t5_embs'] = negative_t5_embs if_II_kwargs['progress'] = progress if_II_kwargs['support_noise'] = mid_res if 'inpainting_mask' not in if_II_kwargs: inpainting_mask_II = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w))) inpainting_mask_II = torch.from_numpy(inpainting_mask_II).unsqueeze(0).to(if_II.device) if_II_kwargs['inpainting_mask'] = inpainting_mask_II stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs) pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark) result['II'] = pil_images_II else: stageII_generations = None if if_II is not None and if_III is not None: _, _, image_h, image_w = high_res.shape if_III_kwargs = if_III_kwargs or {} stageIII_generations = [] for idx in range(len(stageII_generations)): if if_III.use_diffusers: if_III_kwargs['prompt'] = prompt[idx: idx+1] if_III_kwargs['low_res'] = stageII_generations[idx:idx+1] if_III_kwargs['seed'] = seed if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1] if negative_t5_embs is not None: if_III_kwargs['negative_t5_embs'] = negative_t5_embs[idx:idx+1] if_III_kwargs['progress'] = progress if_III_kwargs['support_noise'] = high_res if 'inpainting_mask' not in if_III_kwargs: inpainting_mask_III = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w))) 
                inpainting_mask_III = torch.from_numpy(inpainting_mask_III).unsqueeze(0).to(if_III.device)
                if_III_kwargs['inpainting_mask'] = inpainting_mask_III

            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
            stageIII_generations.append(_stageIII_generations)

        stageIII_generations = torch.cat(stageIII_generations, 0)
        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)

        result['III'] = pil_images_III
    else:
        stageIII_generations = None

    if return_tensors:
        return result, (stageI_generations, stageII_generations, stageIII_generations)
    else:
        return result

================================================
FILE: deepfloyd_if/pipelines/style_transfer.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime

import PIL
import torch

from .utils import _prepare_pil_image


def style_transfer(
    t5,
    if_I,
    if_II,
    if_III=None,
    *,
    support_pil_img,
    style_prompt,
    prompt=None,
    negative_prompt=None,
    seed=None,
    if_I_kwargs=None,
    if_II_kwargs=None,
    if_III_kwargs=None,
    progress=True,
    return_tensors=False,
    disable_watermark=False,
):
    assert isinstance(support_pil_img, PIL.Image.Image)
    bs = len(style_prompt)

    if seed is None:
        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))

    if prompt is not None:
        t5_embs = t5.get_text_embeddings(prompt)
    else:
        t5_embs = t5.get_text_embeddings(style_prompt)

    style_t5_embs = t5.get_text_embeddings(style_prompt)

    if negative_prompt is not None:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        negative_t5_embs = t5.get_text_embeddings(negative_prompt)
    else:
        negative_t5_embs = None

    low_res = _prepare_pil_image(support_pil_img, 64)
    mid_res = _prepare_pil_image(support_pil_img, 256)
    # high_res = _prepare_pil_image(support_pil_img, 1024)

    result = {}
    if if_I is not None:
        _, _, image_h, image_w = low_res.shape

        if_I_kwargs = if_I_kwargs or {'sample_timestep_respacing': '20,20,20,20,10,0,0,0,0,0'}
        if_I_kwargs['seed'] = seed
        if_I_kwargs['progress'] = progress
        if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'
        if_I_kwargs['t5_embs'] = t5_embs
        if_I_kwargs['style_t5_embs'] = style_t5_embs
        if_I_kwargs['positive_t5_embs'] = style_t5_embs
        if_I_kwargs['negative_t5_embs'] = negative_t5_embs
        if_I_kwargs['support_noise'] = low_res

        stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)
        pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)

        result['I'] = pil_images_I
    else:
        stageI_generations = None

    if if_II is not None:
        if stageI_generations is None:
            stageI_generations = low_res.repeat(bs, 1, 1, 1)

        if_II_kwargs = if_II_kwargs or {}
        if_II_kwargs['low_res'] = stageI_generations
        if_II_kwargs['seed'] = seed
        if_II_kwargs['t5_embs'] = t5_embs
        if_II_kwargs['style_t5_embs'] = style_t5_embs
        if_II_kwargs['positive_t5_embs'] = style_t5_embs
        if_II_kwargs['negative_t5_embs'] = negative_t5_embs
        if_II_kwargs['progress'] = progress
        if_II_kwargs['support_noise'] = mid_res

        stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)
        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)

        result['II'] = pil_images_II
    else:
        stageII_generations = None

    if if_II is not None and if_III is not None:
        if_III_kwargs = if_III_kwargs or {}

        stageIII_generations = []
        for idx in range(len(stageII_generations)):
            if if_III.use_diffusers:
                if_III_kwargs['prompt'] = prompt[idx: idx+1] if prompt is not None else style_prompt[idx: idx+1]

            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
            if_III_kwargs['seed'] = seed
            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
            if_III_kwargs['progress'] = progress
            style_t5_embs = if_II_kwargs.get('style_t5_embs')
            if style_t5_embs is not None:
                style_t5_embs = style_t5_embs[idx:idx+1]
            positive_t5_embs = if_II_kwargs.get('positive_t5_embs')
            if positive_t5_embs is not None:
                positive_t5_embs = positive_t5_embs[idx:idx+1]
            if_III_kwargs['style_t5_embs'] = style_t5_embs
            if_III_kwargs['positive_t5_embs'] = positive_t5_embs

            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
            stageIII_generations.append(_stageIII_generations)

        stageIII_generations = torch.cat(stageIII_generations, 0)
        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)

        result['III'] = pil_images_III
    else:
        stageIII_generations = None

    if return_tensors:
        return result, (stageI_generations, stageII_generations, stageIII_generations)
    else:
        return result

================================================
FILE: deepfloyd_if/pipelines/super_resolution.py
================================================
# -*- coding: utf-8 -*-
from datetime import datetime

import PIL

from .utils import _prepare_pil_image


def super_resolution(
    t5,
    if_III=None,
    *,
    support_pil_img,
    prompt=None,
    negative_prompt=None,
    seed=None,
    if_III_kwargs=None,
    progress=True,
    img_size=256,
    img_scale=4.0,
    return_tensors=False,
    disable_watermark=False,
):
    assert isinstance(support_pil_img, PIL.Image.Image)
    assert img_size % 8 == 0

    if seed is None:
        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))

    if prompt is not None:
        t5_embs = t5.get_text_embeddings(prompt)
    else:
        t5_embs = t5.get_text_embeddings('')

    if negative_prompt is not None:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        negative_t5_embs = t5.get_text_embeddings(negative_prompt)
    else:
        negative_t5_embs = None

    low_res = _prepare_pil_image(support_pil_img, img_size)

    result = {}

    bs = 1

    if_III_kwargs = if_III_kwargs or {}

    if if_III.use_diffusers:
        if_III_kwargs['prompt'] = prompt

    if_III_kwargs['low_res'] = low_res.repeat(bs, 1, 1, 1)
    if_III_kwargs['seed'] = seed
    if_III_kwargs['t5_embs'] = t5_embs
    if_III_kwargs['negative_t5_embs'] = negative_t5_embs
    if_III_kwargs['progress'] = progress
    if_III_kwargs['img_scale'] = img_scale

    stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)
    pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)
    result['III'] = pil_images_III

    if return_tensors:
        return result, (stageIII_generations,)
    else:
        return result

================================================
FILE: deepfloyd_if/pipelines/utils.py
================================================
# -*- coding: utf-8 -*-
import torch
import numpy as np
from PIL import Image


def _prepare_pil_image(raw_pil_img, img_size):
    raw_pil_img = raw_pil_img.convert('RGB')
    w, h = raw_pil_img.size
    coef = w / h
    image_h, image_w = img_size, img_size
    if coef >= 1:
        image_w = int(round(img_size / 8 * coef) * 8)
    else:
        image_h = int(round(img_size / 8 / coef) * 8)

    pil_img = raw_pil_img.resize(
        (image_w, image_h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None
    )
    img = np.array(pil_img)
    img = img.astype(np.float32) / 127.5 - 1
    img = np.transpose(img, [2, 0, 1])
    img = torch.from_numpy(img).unsqueeze(0)
    return img

================================================
FILE: deepfloyd_if/utils.py
================================================
# -*- coding: utf-8 -*-
from os.path import abspath, dirname, join

import torch
import numpy as np
from PIL import Image, ImageFilter

RESOURCES_ROOT = join(abspath(dirname(__file__)), 'resources')


def drop_shadow(image, offset=(5, 5), background=0xffffff, shadow=0x444444, border=8, iterations=3):
    """
    Drop shadows with PIL.
    Author: Kevin Schluff
    License: Python license
    https://code.activestate.com/recipes/474116/

    Add a gaussian blur drop shadow to an image.

    image       - The image to overlay on top of the shadow.
    offset      - Offset of the shadow from the image as an (x,y) tuple. Can be
                  positive or negative.
    background  - Background colour behind the image.
    shadow      - Shadow colour (darkness).
    border      - Width of the border around the image. This must be wide
                  enough to account for the blurring of the shadow.
    iterations  - Number of times to apply the filter. More iterations
                  produce a more blurred shadow, but increase processing time.
    """
    # Create the backdrop image -- a box in the background colour with a
    # shadow on it.
    total_width = image.size[0] + abs(offset[0]) + 2 * border
    total_height = image.size[1] + abs(offset[1]) + 2 * border
    back = Image.new(image.mode, (total_width, total_height), background)

    # Place the shadow, taking into account the offset from the image
    shadow_left = border + max(offset[0], 0)
    shadow_top = border + max(offset[1], 0)
    back.paste(shadow, [shadow_left, shadow_top, shadow_left + image.size[0], shadow_top + image.size[1]])

    # Apply the filter to blur the edges of the shadow. Since a small kernel
    # is used, the filter must be applied repeatedly to get a decent blur.
    n = 0
    while n < iterations:
        back = back.filter(ImageFilter.BLUR)
        n += 1

    # Paste the input image onto the shadow backdrop
    image_left = border - min(offset[0], 0)
    image_top = border - min(offset[1], 0)
    back.paste(image, (image_left, image_top))
    return back


def pil_list_to_torch_tensors(pil_images):
    result = []
    for pil_image in pil_images:
        image = np.array(pil_image, dtype=np.uint8)
        image = torch.from_numpy(image)
        image = image.permute(2, 0, 1).unsqueeze(0)
        result.append(image)
    return torch.cat(result, dim=0)

================================================
FILE: requirements-dev.txt
================================================
-r requirements-test.txt
pre-commit

================================================
FILE: requirements-test.txt
================================================
-r requirements.txt
pytest
pytest-cov

================================================
FILE: requirements.txt
================================================
tqdm
numpy
torch<2.0.0
torchvision
omegaconf
matplotlib
Pillow>=9.2.0
huggingface_hub>=0.13.2
transformers~=4.25.1
accelerate~=0.15.0
diffusers~=0.16.0
tokenizers~=0.13.2
sentencepiece~=0.1.97
ftfy~=6.1.1
beautifulsoup4~=4.11.1

================================================
FILE: setup.cfg
================================================
[pep8]
max-line-length = 120
exclude = .tox,*migrations*,.json

[flake8]
max-line-length = 120
exclude = .tox,*migrations*,.json

[autopep8-wrapper]
exclude = .tox,*migrations*,.json

[check-docstring-first]
exclude = .tox,*migrations*,.json

================================================
FILE: setup.py
================================================
# -*- coding: utf-8 -*-
import os
import re

from setuptools import setup


def read(filename):
    with open(os.path.join(os.path.dirname(__file__), filename)) as f:
        file_content = f.read()
    return file_content


def get_requirements():
    requirements = []
    for requirement in read('requirements.txt').splitlines():
        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):
            parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
            if parsed_requires:
                package, version = parsed_requires[0]
                requirements.append(f'{package}=={version}')
            else:
                # The two literals are concatenated, so the first one needs a trailing space.
                print('WARNING! For correct matching dependency links need to specify package name and version '
                      'such as #egg=-')
        else:
            requirements.append(requirement)
    return requirements


def get_links():
    return [
        requirement for requirement in read('requirements.txt').splitlines()
        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')
    ]


def get_version():
    """
    Get version from the package without actually importing it.
    """
    init = read('deepfloyd_if/__init__.py')
    for line in init.split('\n'):
        if line.startswith('__version__'):
            return eval(line.split('=')[1])


setup(
    name='deepfloyd_if',
    version=get_version(),
    author='DeepFloyd, StabilityAI',
    author_email='shonenkov@gmail.com',
    description='DeepFloyd-IF (Imagen Free)',
    packages=['deepfloyd_if', 'deepfloyd_if/model', 'deepfloyd_if/modules', 'deepfloyd_if/pipelines',
              'deepfloyd_if/resources'],
    package_data={'deepfloyd_if/resources': ['*.png', '*.npy', '*.npz']},
    install_requires=get_requirements(),
    dependency_links=get_links(),
    long_description=read('README.md'),
    long_description_content_type='text/markdown',
)
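
`_prepare_pil_image` in `deepfloyd_if/pipelines/utils.py` sizes every support image before a stage sees it: the shorter side becomes `img_size`, the longer side is scaled by the aspect ratio and rounded to a multiple of 8, and pixel values are mapped from [0, 255] to [-1, 1]. The standalone sketch here is not part of the repository; `target_size` and `to_unit_range` are hypothetical helpers that copy that arithmetic without the PIL and torch calls, so the sizes produced for the 64, 256 and 1024 px stage inputs used by `inpainting` and `style_transfer` can be checked quickly.

```python
from typing import Tuple


def target_size(width: int, height: int, img_size: int) -> Tuple[int, int]:
    """Hypothetical helper: return (image_w, image_h) using the same rule as
    _prepare_pil_image. The shorter side becomes img_size; the longer side is
    scaled by the aspect ratio and rounded to the nearest multiple of 8.
    """
    coef = width / height
    image_w, image_h = img_size, img_size
    if coef >= 1:
        # Landscape or square: keep the height, stretch the width.
        image_w = int(round(img_size / 8 * coef) * 8)
    else:
        # Portrait: keep the width, stretch the height.
        image_h = int(round(img_size / 8 / coef) * 8)
    return image_w, image_h


def to_unit_range(pixel: int) -> float:
    """Hypothetical helper: map a uint8 value in [0, 255] to [-1, 1],
    mirroring `img / 127.5 - 1` in _prepare_pil_image."""
    return pixel / 127.5 - 1


if __name__ == '__main__':
    # A 640x480 support image at the three stage input sizes.
    for size in (64, 256, 1024):
        print(size, target_size(640, 480, size))
    # prints: 64 (88, 64), then 256 (344, 256), then 1024 (1368, 1024)
```

Because the helper relies on Python's built-in `round`, ties such as an exact `x.5` go to the even integer, which matches the behaviour of `_prepare_pil_image`.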