[
  {
    "path": ".gitattributes",
    "content": "notebooks/pipes-DeepFloyd-IF.ipynb filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.idea\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.2.0\n    hooks:\n    -   id: check-docstring-first\n    -   id: check-merge-conflict\n        stages:\n        - push\n    -   id: double-quote-string-fixer\n    -   id: end-of-file-fixer\n    -   id: fix-encoding-pragma\n    -   id: mixed-line-ending\n    -   id: trailing-whitespace\n-   repo: https://github.com/pycqa/flake8\n    rev: \"4.0.1\"\n    hooks:\n    -   id: flake8\n        args: ['--config=setup.cfg']\n-   repo: https://github.com/pre-commit/mirrors-autopep8\n    rev: v1.6.0\n    hooks:\n    -   id: autopep8\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "v1.0.2rc\n--------\n\n- uses a separate `tokenizer_path` to initialize the tokenizer in `T5Embedder`\n\nv1.0.1\n------\n\n- renamed main model `IF-I-IF` --> `IF-I-XL`\n- moved the `notebooks` dir to HF storage https://huggingface.co/DeepFloyd/IF-notebooks; let's keep new notebooks there;\n- added an additional Kaggle notebook (more free GPU resources) showing how to generate 1k pictures: [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)\n\nv1.0.0\n------\n\n- initial version\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2023 DeepFloyd, StabilityAI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\n1. The above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\n2. All persons obtaining a copy or substantial portion of the Software,\na modified version of the Software (or substantial portion thereof), or\na derivative work based upon this Software (or substantial portion thereof)\nmust not delete, remove, disable, diminish, or circumvent any inference filters or\ninference filter mechanisms in the Software, or any portion of the Software that\nimplements any such filters or filter mechanisms.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "LICENSE-MODEL",
    "content": "DEEPFLOYD IF LICENSE AGREEMENT\n\nThis License Agreement (as may be amended in accordance with this License Agreement, “License”),\nbetween you, or your employer or other entity (if you are entering into this agreement on behalf\nof your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”)\napplies to your use of any computer program, algorithm, source code, object code, or software that is made\navailable by Stability AI under this License (“Software”) and any specifications, manuals, documentation,\nand other written information provided by Stability AI related to the Software (“Documentation”).\nBy clicking “I Accept” below or by using the Software, you agree to the terms of this License.\nIf you do not agree to this License, then you do not have any rights to use the Software or\nDocumentation (collectively, the “Software Products”), and you must immediately cease using\nthe Software Products. If you are agreeing to be bound by the terms of this License on behalf\nof your employer or other entity, you represent and warrant to Stability AI that you have full legal\nauthority to bind your employer or such entity to this License. If you do not have the requisite authority,\nyou may not accept the License or access the Software Products on behalf of your employer or other entity.\n\n1. LICENSE GRANT\n\na. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants\nyou a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited\nlicense under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of\nthe Software solely for your non-commercial research purposes. 
The foregoing license is personal to you,\nand you may not assign or sublicense this License or any other rights or obligations under this License\nwithout Stability AI’s prior written consent; any such assignment or sublicense will be void and will\nautomatically and immediately terminate this License.\n\nb. You may make a reasonable number of copies of the Documentation solely for use in connection with\nthe license to the Software granted above.\n\nc. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete\ngrant of rights to you in the Software Products, and no other licenses are granted, whether by waiver,\nestoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights\nnot expressly granted by this License.\n\n\n2. RESTRICTIONS\n\nYou will not, and will not permit, assist or cause any third party to:\n\na. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products\n(or any derivative works thereof, works incorporating the Software Products, or any data produced\nby the Software), in whole or in part, for (i) any commercial or production purposes,\n(ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance,\nincluding any research or development relating to surveillance, (iv) biometric processing,\n(v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights,\nor (vi) in any manner that violates any applicable law and violating any privacy or security laws,\nrules, regulations, directives, or governmental requirements (including the General Data Privacy\nRegulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws\ngoverning the processing of biometric information), as well as all amendments and successor laws\nto any of the foregoing;\n\nb. alter or remove copyright and other proprietary notices which appear on or in the Software Products;\n\nc. 
utilize any equipment, device, software, or other means to circumvent or remove any security or\nprotection used by Stability AI in connection with the Software, or to circumvent or remove any\nusage restrictions, or to enable functionality disabled by Stability AI; or\n\nd. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent\nwith the terms of this License.\n\ne. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws\n(“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise\ntransfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b)\nto anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited\nby Export Laws, including nuclear, chemical or biological weapons, or missile technology applications;\n3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned\njurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any\npurpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.\n\n\n3. ATTRIBUTION\n\nTogether with any copies of the Software Products (as well as derivative works thereof or works\nincorporating the Software Products) that you distribute, you must provide (i) a copy of this License,\nand (ii) the following attribution notice: “DeepFloyd is licensed under the DeepFloyd License,\nCopyright (c) Stability AI Ltd. All Rights Reserved.”\n\n\n4. DISCLAIMERS\n\nTHE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND,\nEXPRESS OR IMPLIED. 
STABILITY AI EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED,\nWHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS,\nINCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,\nTITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS\nTHAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS,\nOR PRODUCE ANY PARTICULAR RESULTS.\n\n\n5. LIMITATION OF LIABILITY\n\nTO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER\nANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY,\nOR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL,\nPUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY\nOF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT\n(COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR\nSITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD\nTO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S\nPRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”).\nIF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK.\nYOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND\nPOLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY\nOF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL\nTHAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.\n\n\n6. 
INDEMNIFICATION\n\nYou will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates,\nand each of our respective shareholders, directors, officers, employees, agents, successors,\nand assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities,\ndamages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any\nStability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or\ninvestigation (collectively, “Claims”) arising out of or related to: (a) your access to or\nuse of the Software Products (as well as any results or data generated from such access or use),\nincluding any High-Risk Use (defined below); (b) your violation of this License; or (c)\nyour violation, misappropriation or infringement of any rights of another (including intellectual\nproperty or other proprietary rights and privacy rights). You will promptly notify the Stability AI\nParties of any such Claims, and cooperate with Stability AI Parties in defending such Claims.\nYou will also grant the Stability AI Parties sole control of the defense or settlement,\nat Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of,\nany other indemnities or remedies set forth in a written agreement between you and\nStability AI or the other Stability AI Parties.\n\n\n7. TERMINATION; SURVIVAL\n\na. This License will automatically terminate upon any breach by you of the terms of this License.\n\nb. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.\n\nc. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution),\n4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival),\n8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).\n\n\n8. 
THIRD PARTY MATERIALS\n\nThe Software Products may contain third-party software or other components (including free and\nopen source software) (all of the foregoing, “Third Party Materials”), which are subject to\nthe license terms of the respective third-party licensors. Your dealings or correspondence\nwith third parties and your use of or interaction with any Third Party Materials are solely\nbetween you and the third party. Stability AI does not control or endorse, and makes\nno representations or warranties regarding, any Third Party Materials, and your access\nto and use of such Third Party Materials are at your own risk.\n\n\n9. TRADEMARKS\n\nLicensee has not been granted any trademark license as part of this License and may not use any name\nor mark associated with Stability AI without the prior written permission of Stability AI, except to\nthe extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.\n\n\n10. APPLICABLE LAW; DISPUTE RESOLUTION\n\nThis License will be governed and construed under the laws of the State of California without regard\nto conflicts of law provisions. Any suit or proceeding arising out of or relating to this License\nwill be brought in the federal or state courts, as applicable, in San Mateo County, California,\nand each party irrevocably submits to the jurisdiction and venue of such courts.\n\n\n11. MISCELLANEOUS\n\nIf any provision or part of a provision of this License is unlawful, void or unenforceable,\nthat provision or part of the provision is deemed severed from this License, and will not affect\nthe validity and enforceability of any remaining provisions. 
The failure of Stability AI to exercise\nor enforce any right or provision of this License will not operate as a waiver of such right or provision.\nThis License does not confer any third-party beneficiary rights upon any other person or entity.\nThis License, together with the Documentation, contains the entire understanding between you and\nStability AI regarding the subject matter of this License, and supersedes all other written or\noral agreements and understandings between you and Stability AI regarding such subject matter.\nNo change or addition to any provision of this License will be binding unless it is in writing and\nsigned by an authorized representative of both you and Stability AI.\n"
  },
  {
    "path": "README.md",
    "content": "[![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE)\n[![License](https://img.shields.io/badge/Weights_License-DeepFloyd_IF-orange.svg)](LICENSE-MODEL)\n[![Downloads](https://pepy.tech/badge/deepfloyd_if)](https://pepy.tech/project/deepfloyd_if)\n[![Discord](https://img.shields.io/badge/Discord-%237289DA.svg?logo=discord&logoColor=white)](https://discord.gg/umz62Mgr)\n[![Twitter](https://img.shields.io/badge/Twitter-%231DA1F2.svg?logo=twitter&logoColor=white)](https://twitter.com/deepfloydai)\n[![Linktree](https://img.shields.io/badge/Linktree-%2339E09B.svg?logo=linktree&logoColor=white)](http://linktr.ee/deepfloyd)\n\n# IF by [DeepFloyd Lab](https://deepfloyd.ai) at [StabilityAI](https://stability.ai/)\n\n<p align=\"center\">\n  <img src=\"./pics/nabla.jpg\" width=\"100%\">\n</p>\n\nWe introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular model composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates a 64x64 px image from a text prompt, and two super-resolution models that upscale it to 256x256 px and then to 1024x1024 px. All stages of the model use a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. 
Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and points to a promising future for text-to-image synthesis.\n\n<p align=\"center\">\n  <img src=\"./pics/deepfloyd_if_scheme.jpg\" width=\"100%\">\n</p>\n\n*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf)\n\n## Minimum requirements to use all IF models:\n- 16GB VRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module)\n- 24GB VRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler)\n- `xformers` installed, with the environment variable `FORCE_MEM_EFFICIENT_ATTN=1` set\n\n\n## Quick Start\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)\n[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF)\n\n```shell\npip install deepfloyd_if==1.0.2rc0\npip install xformers==0.0.16\npip install git+https://github.com/openai/CLIP.git --no-deps\n```\n\n## Local notebooks\n[![Jupyter Notebook](https://img.shields.io/badge/jupyter_notebook-%23FF7A01.svg?logo=jupyter&logoColor=white)](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb)\n[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)\n\nThe Dream, Style Transfer, Super Resolution and Inpainting modes are available in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb).\n\n\n\n## Integration with 🤗 Diffusers\n\nIF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/).\n\nDiffusers 
runs each stage individually, allowing the user to customize the image generation process and to inspect intermediate results easily.\n\n### Example\n\nBefore you can use IF, you need to accept its usage conditions. To do so:\n1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in\n2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)\n3. Make sure to log in locally. Install `huggingface_hub`\n```sh\npip install huggingface_hub --upgrade\n```\n\nrun the login function in a Python shell\n\n```py\nfrom huggingface_hub import login\n\nlogin()\n```\n\nand enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).\n\nNext we install `diffusers` and dependencies:\n\n```sh\npip install diffusers accelerate transformers safetensors\n```\n\nAnd we can now run the model locally.\n\nThe example below uses [model CPU offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM.\n\nIf you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()`\ncalls.\n\n```py\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import pt_to_pil\nimport torch\n\n# stage 1\nstage_1 = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\nstage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0\nstage_1.enable_model_cpu_offload()\n\n# stage 2\nstage_2 = DiffusionPipeline.from_pretrained(\n    \"DeepFloyd/IF-II-L-v1.0\", text_encoder=None, variant=\"fp16\", torch_dtype=torch.float16\n)\nstage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0\nstage_2.enable_model_cpu_offload()\n\n# stage 
3\nsafety_modules = {\"feature_extractor\": stage_1.feature_extractor, \"safety_checker\": stage_1.safety_checker, \"watermarker\": stage_1.watermarker}\nstage_3 = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-x4-upscaler\", **safety_modules, torch_dtype=torch.float16)\nstage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0\nstage_3.enable_model_cpu_offload()\n\nprompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says \"very deep learning\"'\n\n# text embeds\nprompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)\n\ngenerator = torch.manual_seed(0)\n\n# stage 1\nimage = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type=\"pt\").images\npt_to_pil(image)[0].save(\"./if_stage_I.png\")\n\n# stage 2\nimage = stage_2(\n    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type=\"pt\"\n).images\npt_to_pil(image)[0].save(\"./if_stage_II.png\")\n\n# stage 3\nimage = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images\nimage[0].save(\"./if_stage_III.png\")\n```\n\nThere are multiple ways to speed up inference and lower memory consumption even further with `diffusers`. 
To do so, please have a look at the Diffusers docs:\n\n- 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed)\n- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory)\n\nFor more detailed information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖.\n\nThe Diffusers DreamBooth scripts also support fine-tuning 🎨 [IF](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#if).\nWith parameter-efficient fine-tuning, you can add new concepts to IF with a single GPU and ~28 GB VRAM.\n\n## Run the code locally\n\n### Loading the models into VRAM\n\n```python\nfrom deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII\nfrom deepfloyd_if.modules.t5 import T5Embedder\n\ndevice = 'cuda:0'\nif_I = IFStageI('IF-I-XL-v1.0', device=device)\nif_II = IFStageII('IF-II-L-v1.0', device=device)\nif_III = StableStageIII('stable-diffusion-x4-upscaler', device=device)\nt5 = T5Embedder(device=\"cpu\")\n```\n\n### I. Dream\nDream is the text-to-image mode of the IF model.\n\n```python\nfrom deepfloyd_if.pipelines import dream\n\nprompt = 'ultra close-up color photo portrait of rainbow owl with deer horns in the woods'\ncount = 4\n\nresult = dream(\n    t5=t5, if_I=if_I, if_II=if_II, if_III=if_III,\n    prompt=[prompt]*count,\n    seed=42,\n    if_I_kwargs={\n        \"guidance_scale\": 7.0,\n        \"sample_timestep_respacing\": \"smart100\",\n    },\n    if_II_kwargs={\n        \"guidance_scale\": 4.0,\n        \"sample_timestep_respacing\": \"smart50\",\n    },\n    if_III_kwargs={\n        \"guidance_scale\": 9.0,\n        \"noise_level\": 20,\n        \"sample_timestep_respacing\": \"75\",\n    },\n)\n\nif_III.show(result['III'], size=14)\n```\n![](./pics/dream-III.jpg)\n\n### II. 
Zero-shot Image-to-Image Translation\n\n![](./pics/img_to_img_scheme.jpeg)\n\nIn Style Transfer mode, the image passed as `support_pil_img` (here `raw_pil_image`, a PIL image you have already loaded) is re-rendered in the style described by each entry of `style_prompt`:\n```python\nfrom deepfloyd_if.pipelines import style_transfer\n\nresult = style_transfer(\n    t5=t5, if_I=if_I, if_II=if_II,\n    support_pil_img=raw_pil_image,\n    style_prompt=[\n        'in style of professional origami',\n        'in style of oil art, Tate modern',\n        'in style of plastic building bricks',\n        'in style of classic anime from 1990',\n    ],\n    seed=42,\n    if_I_kwargs={\n        \"guidance_scale\": 10.0,\n        \"sample_timestep_respacing\": \"10,10,10,10,10,10,10,10,0,0\",\n        'support_noise_less_qsample_steps': 5,\n    },\n    if_II_kwargs={\n        \"guidance_scale\": 4.0,\n        \"sample_timestep_respacing\": 'smart50',\n        \"support_noise_less_qsample_steps\": 5,\n    },\n)\nif_I.show(result['II'], 1, 20)\n```\n\n![Zero-shot image-to-image translation examples](./pics/deep_floyd_if_image_2_image.gif)\n\n\n### III. 
Super Resolution\nFor super-resolution, users can run `IF-II` followed by `IF-III` or `Stable x4` (two cascades) on an image that was not necessarily generated by IF:\n\n```python\nfrom deepfloyd_if.pipelines import super_resolution\n\nmiddle_res = super_resolution(\n    t5,\n    if_III=if_II,\n    prompt=['woman with a blue headscarf and a blue sweater, detailed picture, 4k dslr, best quality'],\n    support_pil_img=raw_pil_image,\n    img_scale=4.,\n    img_size=64,\n    if_III_kwargs={\n        'sample_timestep_respacing': 'smart100',\n        'aug_level': 0.5,\n        'guidance_scale': 6.0,\n    },\n)\nhigh_res = super_resolution(\n    t5,\n    if_III=if_III,\n    prompt=[''],\n    support_pil_img=middle_res['III'][0],\n    img_scale=4.,\n    img_size=256,\n    if_III_kwargs={\n        \"guidance_scale\": 9.0,\n        \"noise_level\": 20,\n        \"sample_timestep_respacing\": \"75\",\n    },\n)\nshow_superres(raw_pil_image, high_res['III'][0])\n```\n\n![](./pics/if_as_upscaler.jpg)\n\n\n### IV. 
Zero-shot Inpainting\n\n```python\nfrom deepfloyd_if.pipelines import inpainting\n\nresult = inpainting(\n    t5=t5, if_I=if_I,\n    if_II=if_II,\n    if_III=if_III,\n    support_pil_img=raw_pil_image,\n    inpainting_mask=inpainting_mask,\n    prompt=[\n        'oil art, a man in a hat',\n    ],\n    seed=42,\n    if_I_kwargs={\n        \"guidance_scale\": 7.0,\n        \"sample_timestep_respacing\": \"10,10,10,10,10,0,0,0,0,0\",\n        'support_noise_less_qsample_steps': 0,\n    },\n    if_II_kwargs={\n        \"guidance_scale\": 4.0,\n        'aug_level': 0.0,\n        \"sample_timestep_respacing\": '100',\n    },\n    if_III_kwargs={\n        \"guidance_scale\": 9.0,\n        \"noise_level\": 20,\n        \"sample_timestep_respacing\": \"75\",\n    },\n)\nif_I.show(result['I'], 2, 3)\nif_I.show(result['II'], 2, 6)\nif_I.show(result['III'], 2, 14)\n```\n![](./pics/deep_floyd_if_inpainting.gif)\n\n### 🤗 Model Zoo 🤗\nLinks to download the weights, as well as the model cards, will be available soon for each model in the model zoo.\n\n#### Original\n\n| Name                                                      | Cascade | Params | FID  | Batch size | Steps |\n|:----------------------------------------------------------|:-------:|:------:|:----:|:----------:|:-----:|\n| [IF-I-M](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)    |    I    |  400M  | 8.86 |    3072    | 2.5M  |\n| [IF-I-L](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)    |    I    |  900M  | 8.06 |    3200    | 3.0M  |\n| [IF-I-XL](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)* |    I    |  4.3B  | 6.66 |    3072    | 2.42M |\n| [IF-II-M](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)  |   II    |  450M  |  -   |    1536    | 2.5M  |\n| [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* |   II    |  1.2B  |  -   |    1536    | 2.5M  |\n| IF-III-L* _(soon)_                                        |   III   |  700M  |  -   |    3072    | 1.25M |\n\n *best modules\n\n### Quantitative 
Evaluation\n\n`FID = 6.66`\n\n![](./pics/fid30k_if.jpg)\n\n## License\n\nThe code in this repository is released under a bespoke modified MIT license (see the added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)).\n\nThe weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE.\n\n**Disclaimer:** *The initial release of the IF model is temporarily under a restricted research-purposes-only license to gather feedback; after that, we intend to release a fully open-source model in line with other Stability AI models.*\n\n## Limitations and Biases\n\nThe models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information.\n\n\n## 🎓 DeepFloyd IF creators:\n\n- Alex Shonenkov [GitHub](https://github.com/shonenkov) | [Linktr](https://linktr.ee/shonenkovAI)\n- Misha Konstantinov [GitHub](https://github.com/zeroshot-ai) | [Twitter](https://twitter.com/_bra_ket)\n- Daria Bakshandaeva [GitHub](https://github.com/Gugutse) | [Twitter](https://twitter.com/_gugutse_)\n- Christoph Schuhmann [GitHub](https://github.com/christophschuhmann) | [Twitter](https://twitter.com/laion_ai)\n- Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv)\n- Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls)\n\n\n## 📄 Research Paper (Soon)\n\n## Acknowledgements\n\nSpecial thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for their contributions to the project and well-prepared datasets; thanks to 
[Hugging Face](https://huggingface.co) teams for optimizing models' speed and memory consumption during inference, creating demos and giving cool advice!\n\n## 🚀 External Contributors 🚀\n- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support at all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉;\n- Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving the loading time of UNet models by 80%;\nfor integrating Stable-Diffusion-x4 as a native pipeline 💪;\n- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌;\n- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀;\n- Thanks, [@Dango233](https://github.com/Dango233), for adapting IF to xformers memory-efficient attention 💪;\n"
  },
  {
    "path": "deepfloyd_if/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\n__version__ = '1.0.2rc0'\n"
  },
  {
    "path": "deepfloyd_if/model/__init__.py",
    "content": "# -*- coding: utf-8 -*-\nfrom .unet import UNetModel, SuperResUNetModel\n\n\n__all__ = ['UNetModel', 'SuperResUNetModel']\n"
  },
  {
    "path": "deepfloyd_if/model/gaussian_diffusion.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nThis code started out as a PyTorch port of Ho et al's diffusion model:\nhttps://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py\nDocstrings have been added, as well as DDIM sampling and a new collection of beta schedules.\n\"\"\"\n\nimport enum\nimport math\nimport numpy as np\nimport torch\n\nfrom .nn import mean_flat\nfrom .losses import normal_kl, discretized_gaussian_log_likelihood\n\n\ndef get_named_beta_schedule(schedule_name, num_diffusion_timesteps):\n    \"\"\"\n    Get a pre-defined beta schedule for the given name.\n    The beta schedule library consists of beta schedules which remain similar\n    in the limit of num_diffusion_timesteps.\n    Beta schedules may be added, but should not be removed or changed once\n    they are committed to maintain backwards compatibility.\n    \"\"\"\n    if schedule_name == 'linear':\n        # Linear schedule from Ho et al, extended to work for any number of\n        # diffusion steps.\n        scale = 1000 / num_diffusion_timesteps\n        beta_start = scale * 0.0001\n        beta_end = scale * 0.02\n        return np.linspace(\n            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64\n        )\n    elif schedule_name == 'cosine':\n        return betas_for_alpha_bar(\n            num_diffusion_timesteps,\n            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,\n        )\n    else:\n        raise NotImplementedError(f'unknown beta schedule: {schedule_name}')\n\n\ndef betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    which defines the cumulative product of (1-beta) over time from t = [0,1].\n    :param num_diffusion_timesteps: the number of betas to produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      
produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\nclass ModelMeanType(enum.Enum):\n    \"\"\"\n    Which type of output the model predicts.\n    \"\"\"\n\n    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}\n    START_X = enum.auto()  # the model predicts x_0\n    EPSILON = enum.auto()  # the model predicts epsilon\n\n\nclass ModelVarType(enum.Enum):\n    \"\"\"\n    What is used as the model's output variance.\n    The LEARNED_RANGE option has been added to allow the model to predict\n    values between FIXED_SMALL and FIXED_LARGE, making its job easier.\n    \"\"\"\n\n    LEARNED = enum.auto()\n    FIXED_SMALL = enum.auto()\n    FIXED_LARGE = enum.auto()\n    LEARNED_RANGE = enum.auto()\n\n\nclass LossType(enum.Enum):\n    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)\n    RESCALED_MSE = (\n        enum.auto()\n    )  # use raw MSE loss (with RESCALED_KL when learning variances)\n    KL = enum.auto()  # use the variational lower-bound\n    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB\n\n    def is_vb(self):\n        return self == LossType.KL or self == LossType.RESCALED_KL\n\n\nclass GaussianDiffusion:\n    \"\"\"\n    Utilities for training and sampling diffusion models.\n    Ported directly from here, and then adapted over time for further experimentation.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42\n    :param betas: a 1-D numpy array of betas for each diffusion timestep,\n    
              starting at T and going to 1.\n    :param model_mean_type: a ModelMeanType determining what the model outputs.\n    :param model_var_type: a ModelVarType determining how variance is output.\n    :param loss_type: a LossType determining the loss function to use.\n    :param rescale_timesteps: if True, pass floating point timesteps into the\n                              model so that they are always scaled like in the\n                              original paper (0 to 1000).\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        betas,\n        model_mean_type,\n        model_var_type,\n        loss_type,\n        rescale_timesteps=False,\n    ):\n        self.model_mean_type = model_mean_type\n        self.model_var_type = model_var_type\n        self.loss_type = loss_type\n        self.rescale_timesteps = rescale_timesteps\n\n        # Use float64 for accuracy.\n        betas = np.array(betas, dtype=np.float64)\n        self.betas = betas\n        assert len(betas.shape) == 1, 'betas must be 1-D'\n        assert (betas > 0).all() and (betas <= 1).all()\n\n        self.num_timesteps = int(betas.shape[0])\n\n        alphas = 1.0 - betas\n        self.alphas_cumprod = np.cumprod(alphas, axis=0)\n        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])\n        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)\n        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)\n        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)\n        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)\n        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)\n        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        
self.posterior_variance = (\n            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        )\n        # log calculation clipped because the posterior variance is 0 at the\n        # beginning of the diffusion chain.\n        self.posterior_log_variance_clipped = np.log(\n            np.append(self.posterior_variance[1], self.posterior_variance[1:])\n        )\n        self.posterior_mean_coef1 = (\n            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        )\n        self.posterior_mean_coef2 = (\n            (1.0 - self.alphas_cumprod_prev)\n            * np.sqrt(alphas)\n            / (1.0 - self.alphas_cumprod)\n        )\n\n    def dynamic_thresholding(self, x, p=0.995, c=1.7):\n        \"\"\"\n        Dynamic thresholding, a diffusion sampling technique from Imagen (https://arxiv.org/abs/2205.11487)\n        that makes it possible to use high guidance weights and generate more photorealistic and detailed\n        images than vanilla x.clamp(-1, 1) clipping (static thresholding) allows.\n\n        p - quantile of |x|, computed per sample, that sets the relative clipping threshold for dynamic\n            compression; helps prevent oversaturation. Recommended values: [0.96, 0.99].\n\n        c - absolute upper bound on that clipping threshold (the per-sample threshold is clamped to [1, c]);\n            helps prevent undersaturation and low-contrast issues. Recommended values: [1.5, 2.0].\n        \"\"\"\n        x_shapes = x.shape\n        s = torch.quantile(x.abs().reshape(x_shapes[0], -1), p, dim=-1)\n        s = torch.clamp(s, min=1, max=c)\n        x_compressed = torch.clip(x.reshape(x_shapes[0], -1).T, -s, s) / s\n        x_compressed = x_compressed.T.reshape(x_shapes)\n        return x_compressed\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n        :param x_start: the [N x C x ...] 
tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = (\n            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        )\n        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = _extract_into_tensor(\n            self.log_one_minus_alphas_cumprod, t, x_start.shape\n        )\n        return mean, variance, log_variance\n\n    def q_sample(self, x_start, t, noise=None):\n        \"\"\"\n        Diffuse the data for a given number of diffusion steps.\n        In other words, sample from q(x_t | x_0).\n        :param x_start: the initial data batch.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :param noise: if specified, the split-out normal noise.\n        :return: A noisy version of x_start.\n        \"\"\"\n        if noise is None:\n            noise = torch.randn_like(x_start)\n        assert noise.shape == x_start.shape\n        return (\n            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n            * noise\n        )\n\n    def q_posterior_mean_variance(self, x_start, x_t, t):\n        \"\"\"\n        Compute the mean and variance of the diffusion posterior:\n            q(x_{t-1} | x_t, x_0)\n        \"\"\"\n        assert x_start.shape == x_t.shape\n        posterior_mean = (\n            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start\n            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = _extract_into_tensor(\n            
self.posterior_log_variance_clipped, t, x_t.shape\n        )\n        assert (\n            posterior_mean.shape[0]\n            == posterior_variance.shape[0]\n            == posterior_log_variance_clipped.shape[0]\n            == x_start.shape[0]\n        )\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(\n        self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,\n        denoised_fn=None, model_kwargs=None\n    ):\n        \"\"\"\n        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of\n        the initial x, x_0.\n        :param model: the model, which takes a signal and a batch of timesteps\n                      as input.\n        :param x: the [N x C x ...] tensor at time t.\n        :param t: a 1-D Tensor of timesteps.\n        :param clip_denoised: if True, map the x_0 prediction into [-1, 1] using\n            dynamic_thresholding() with dynamic_thresholding_p and\n            dynamic_thresholding_c.\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample. Applies before\n            clip_denoised.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. 
This can be used for conditioning.\n        :return: a dict with the following keys:\n                 - 'mean': the model mean output.\n                 - 'variance': the model variance output.\n                 - 'log_variance': the log of 'variance'.\n                 - 'pred_xstart': the prediction for x_0.\n        \"\"\"\n        if model_kwargs is None:\n            model_kwargs = {}\n\n        B, C = x.shape[:2]\n        assert t.shape == (B,)\n        model_output = model(x, self._scale_timesteps(t), **model_kwargs)\n\n        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:\n            assert model_output.shape == (B, C * 2, *x.shape[2:])\n            model_output, model_var_values = torch.split(model_output, C, dim=1)\n            if self.model_var_type == ModelVarType.LEARNED:\n                model_log_variance = model_var_values\n                model_variance = torch.exp(model_log_variance)\n            else:\n                min_log = _extract_into_tensor(\n                    self.posterior_log_variance_clipped, t, x.shape\n                )\n                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)\n                # The model_var_values is [-1, 1] for [min_var, max_var].\n                frac = (model_var_values + 1) / 2\n                model_log_variance = frac * max_log + (1 - frac) * min_log\n                model_variance = torch.exp(model_log_variance)\n        else:\n            model_variance, model_log_variance = {\n                # for fixedlarge, we set the initial (log-)variance like so\n                # to get a better decoder log likelihood.\n                ModelVarType.FIXED_LARGE: (\n                    np.append(self.posterior_variance[1], self.betas[1:]),\n                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),\n                ),\n                ModelVarType.FIXED_SMALL: (\n                    self.posterior_variance,\n                    
self.posterior_log_variance_clipped,\n                ),\n            }[self.model_var_type]\n            model_variance = _extract_into_tensor(model_variance, t, x.shape)\n            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)\n\n        def process_xstart(x):\n            if denoised_fn is not None:\n                x = denoised_fn(x)\n            if clip_denoised:\n                x = self.dynamic_thresholding(x, p=dynamic_thresholding_p, c=dynamic_thresholding_c)\n                return x  # x.clamp(-1, 1)\n            return x\n\n        if self.model_mean_type == ModelMeanType.PREVIOUS_X:\n            pred_xstart = process_xstart(\n                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)\n            )\n            model_mean = model_output\n        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:\n            if self.model_mean_type == ModelMeanType.START_X:\n                pred_xstart = process_xstart(model_output)\n            else:\n                pred_xstart = process_xstart(\n                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)\n                )\n            model_mean, _, _ = self.q_posterior_mean_variance(\n                x_start=pred_xstart, x_t=x, t=t\n            )\n        else:\n            raise NotImplementedError(self.model_mean_type)\n\n        assert (\n            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape\n        )\n        return {\n            'mean': model_mean,\n            'variance': model_variance,\n            'log_variance': model_log_variance,\n            'pred_xstart': pred_xstart,\n        }\n\n    def _predict_xstart_from_eps(self, x_t, t, eps):\n        assert x_t.shape == eps.shape\n        return (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps\n      
  )\n\n    def _predict_xstart_from_xprev(self, x_t, t, xprev):\n        assert x_t.shape == xprev.shape\n        return (  # (xprev - coef2*x_t) / coef1\n            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev\n            - _extract_into_tensor(\n                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape\n            )\n            * x_t\n        )\n\n    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - pred_xstart\n        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def _scale_timesteps(self, t):\n        if self.rescale_timesteps:\n            return t.float() * (1000.0 / self.num_timesteps)\n        return t\n\n    def p_sample(\n        self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,\n        denoised_fn=None, model_kwargs=None, inpainting_mask=None,\n    ):\n        \"\"\"\n        Sample x_{t-1} from the model at the given timestep.\n        :param model: the model to sample from.\n        :param x: the current tensor at x_t.\n        :param t: the value of t, starting at 0 for the first diffusion step.\n        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample.\n        :param inpainting_mask: if not None, a tensor shaped like x; where it is 1\n            the new sample is used, and where it is 0 the current x is kept.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. 
This can be used for conditioning.\n        :return: a dict containing the following keys:\n                 - 'sample': a random sample from the model.\n                 - 'pred_xstart': a prediction of x_0.\n        \"\"\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        noise = torch.randn_like(x)\n        nonzero_mask = (\n            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n        )  # no noise when t == 0\n        if inpainting_mask is None:\n            inpainting_mask = torch.ones_like(x, device=x.device)\n\n        sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise\n        sample = (1 - inpainting_mask)*x + inpainting_mask*sample\n        return {'sample': sample, 'pred_xstart': out['pred_xstart']}\n\n    def p_sample_loop(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        dynamic_thresholding_p=0.99,\n        dynamic_thresholding_c=1.7,\n        inpainting_mask=None,\n        denoised_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        sample_fn=None,\n    ):\n        \"\"\"\n        Generate samples from the model.\n        :param model: the model module.\n        :param shape: the shape of the samples, (N, C, H, W).\n        :param noise: if specified, the noise from the encoder to sample.\n                      Should be of the same shape as `shape`.\n        :param clip_denoised: if True, clip x_start predictions to [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample.\n        :param model_kwargs: if not None, a dict of 
extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :param device: if specified, the device to create the samples on.\n                       If not specified, use a model parameter's device.\n        :param progress: if True, show a tqdm progress bar.\n        :return: a non-differentiable batch of samples.\n        \"\"\"\n        final = None\n        for step_idx, sample in enumerate(self.p_sample_loop_progressive(\n            model,\n            shape,\n            noise=noise,\n            clip_denoised=clip_denoised,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            denoised_fn=denoised_fn,\n            inpainting_mask=inpainting_mask,\n            model_kwargs=model_kwargs,\n            device=device,\n            progress=progress,\n        )):\n            if sample_fn is not None:\n                sample = sample_fn(step_idx, sample)\n            final = sample\n        return final['sample']\n\n    def p_sample_loop_progressive(\n        self,\n        model,\n        shape,\n        inpainting_mask=None,\n        noise=None,\n        clip_denoised=True,\n        dynamic_thresholding_p=0.99,\n        dynamic_thresholding_c=1.7,\n        denoised_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n    ):\n        \"\"\"\n        Generate samples from the model and yield intermediate samples from\n        each timestep of diffusion.\n        Arguments are the same as p_sample_loop().\n        Returns a generator over dicts, where each dict is the return value of\n        p_sample().\n        \"\"\"\n        if device is None:\n            device = next(model.parameters()).device\n        assert isinstance(shape, (tuple, list))\n        if noise is not None:\n            img = noise\n        else:\n            img = torch.randn(*shape, device=device)\n        indices = 
list(range(self.num_timesteps))[::-1]\n\n        if progress:\n            # Lazy import so that we don't depend on tqdm.\n            from tqdm.auto import tqdm\n\n            indices = tqdm(indices)\n\n        for i in indices:\n            t = torch.tensor([i] * shape[0], device=device)\n            with torch.no_grad():\n                out = self.p_sample(\n                    model,\n                    img,\n                    t,\n                    clip_denoised=clip_denoised,\n                    dynamic_thresholding_p=dynamic_thresholding_p,\n                    dynamic_thresholding_c=dynamic_thresholding_c,\n                    denoised_fn=denoised_fn,\n                    inpainting_mask=inpainting_mask,\n                    model_kwargs=model_kwargs,\n                )\n                yield out\n                img = out['sample']\n\n    def ddim_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=True,\n        dynamic_thresholding_p=0.99,\n        dynamic_thresholding_c=1.7,\n        denoised_fn=None,\n        model_kwargs=None,\n        eta=0.0,\n    ):\n        \"\"\"\n        Sample x_{t-1} from the model using DDIM.\n        Same usage as p_sample().\n        \"\"\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        # Usually our model outputs epsilon, but we re-derive it\n        # in case we used x_start or x_prev prediction.\n        eps = self._predict_eps_from_xstart(x, t, out['pred_xstart'])\n        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)\n        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)\n        sigma = (\n            eta\n            * torch.sqrt((1 
- alpha_bar_prev) / (1 - alpha_bar))\n            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)\n        )\n        # Equation 12.\n        noise = torch.randn_like(x)\n        mean_pred = (\n            out['pred_xstart'] * torch.sqrt(alpha_bar_prev)\n            + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps\n        )\n        nonzero_mask = (\n            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n        )  # no noise when t == 0\n        sample = mean_pred + nonzero_mask * sigma * noise\n        return {'sample': sample, 'pred_xstart': out['pred_xstart']}\n\n    def ddim_reverse_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=True,\n        dynamic_thresholding_p=0.99,\n        dynamic_thresholding_c=1.7,\n        denoised_fn=None,\n        model_kwargs=None,\n        eta=0.0,\n    ):\n        \"\"\"\n        Sample x_{t+1} from the model using DDIM reverse ODE.\n        \"\"\"\n        assert eta == 0.0, 'Reverse ODE only for deterministic path'\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        # Usually our model outputs epsilon, but we re-derive it\n        # in case we used x_start or x_prev prediction.\n        eps = (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x\n            - out['pred_xstart']\n        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)\n        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)\n\n        # Equation 12. 
reversed\n        mean_pred = (\n            out['pred_xstart'] * torch.sqrt(alpha_bar_next)\n            + torch.sqrt(1 - alpha_bar_next) * eps\n        )\n\n        return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']}\n\n    def ddim_sample_loop(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        dynamic_thresholding_p=0.99,\n        dynamic_thresholding_c=1.7,\n        denoised_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        eta=0.0,\n        sample_fn=None,\n    ):\n        \"\"\"\n        Generate samples from the model using DDIM.\n        Same usage as p_sample_loop().\n        \"\"\"\n        final = None\n        for step_idx, sample in enumerate(self.ddim_sample_loop_progressive(\n            model,\n            shape,\n            noise=noise,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            model_kwargs=model_kwargs,\n            device=device,\n            progress=progress,\n            eta=eta,\n        )):\n            if sample_fn is not None:\n                sample = sample_fn(step_idx, sample)\n            final = sample\n        return final['sample']\n\n    def ddim_sample_loop_progressive(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        dynamic_thresholding_p=0.99,\n        dynamic_thresholding_c=1.7,\n        denoised_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        eta=0.0,\n    ):\n        \"\"\"\n        Use DDIM to sample from the model and yield intermediate samples from\n        each timestep of DDIM.\n        Same usage as p_sample_loop_progressive().\n        \"\"\"\n        if device is None:\n            device = next(model.parameters()).device\n        assert 
isinstance(shape, (tuple, list))\n        if noise is not None:\n            img = noise\n        else:\n            img = torch.randn(*shape, device=device)\n        indices = list(range(self.num_timesteps))[::-1]\n\n        if progress:\n            # Lazy import so that we don't depend on tqdm.\n            from tqdm.auto import tqdm\n\n            indices = tqdm(indices)\n\n        for i in indices:\n            t = torch.tensor([i] * shape[0], device=device)\n            with torch.no_grad():\n                out = self.ddim_sample(\n                    model,\n                    img,\n                    t,\n                    clip_denoised=clip_denoised,\n                    dynamic_thresholding_p=dynamic_thresholding_p,\n                    dynamic_thresholding_c=dynamic_thresholding_c,\n                    denoised_fn=denoised_fn,\n                    model_kwargs=model_kwargs,\n                    eta=eta,\n                )\n                yield out\n                img = out['sample']\n\n    def _vb_terms_bpd(\n        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None\n    ):\n        \"\"\"\n        Get a term for the variational lower-bound.\n        The resulting units are bits (rather than nats, as one might expect).\n        This allows for comparison to other papers.\n        :return: a dict with the following keys:\n                 - 'output': a shape [N] tensor of NLLs or KLs.\n                 - 'pred_xstart': the x_0 predictions.\n        \"\"\"\n        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(\n            x_start=x_start, x_t=x_t, t=t\n        )\n        out = self.p_mean_variance(\n            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs\n        )\n        kl = normal_kl(\n            true_mean, true_log_variance_clipped, out['mean'], out['log_variance']\n        )\n        kl = mean_flat(kl) / np.log(2.0)\n\n        decoder_nll = 
-discretized_gaussian_log_likelihood(\n            x_start, means=out['mean'], log_scales=0.5 * out['log_variance']\n        )\n        assert decoder_nll.shape == x_start.shape\n        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)\n\n        # At the first timestep return the decoder NLL,\n        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))\n        output = torch.where((t == 0), decoder_nll, kl)\n        return {'output': output, 'pred_xstart': out['pred_xstart']}\n\n    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):\n        \"\"\"\n        Compute training losses for a single timestep.\n        :param model: the model to evaluate loss on.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :param t: a batch of timestep indices.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :param noise: if specified, the specific Gaussian noise to try to remove.\n        :return: a dict with the key \"loss\" containing a tensor of shape [N].\n                 Some mean or variance settings may also have other keys.\n        \"\"\"\n        if model_kwargs is None:\n            model_kwargs = {}\n        if noise is None:\n            noise = torch.randn_like(x_start)\n        x_t = self.q_sample(x_start, t, noise=noise)\n\n        terms = {}\n\n        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:\n            terms['loss'] = self._vb_terms_bpd(\n                model=model,\n                x_start=x_start,\n                x_t=x_t,\n                t=t,\n                clip_denoised=False,\n                model_kwargs=model_kwargs,\n            )['output']\n            if self.loss_type == LossType.RESCALED_KL:\n                terms['loss'] *= self.num_timesteps\n        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:\n            
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)\n\n            if self.model_var_type in [\n                ModelVarType.LEARNED,\n                ModelVarType.LEARNED_RANGE,\n            ]:\n                B, C = x_t.shape[:2]\n                assert model_output.shape == (B, C * 2, *x_t.shape[2:])\n                model_output, model_var_values = torch.split(model_output, C, dim=1)\n                # Learn the variance using the variational bound, but don't let\n                # it affect our mean prediction.\n                frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)\n                terms['vb'] = self._vb_terms_bpd(\n                    model=lambda *args, r=frozen_out: r,\n                    x_start=x_start,\n                    x_t=x_t,\n                    t=t,\n                    clip_denoised=False,\n                )['output']\n                if self.loss_type == LossType.RESCALED_MSE:\n                    # Divide by 1000 for equivalence with initial implementation.\n                    # Without a factor of 1/1000, the VB term hurts the MSE term.\n                    terms['vb'] *= self.num_timesteps / 1000.0\n\n            target = {\n                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(\n                    x_start=x_start, x_t=x_t, t=t\n                )[0],\n                ModelMeanType.START_X: x_start,\n                ModelMeanType.EPSILON: noise,\n            }[self.model_mean_type]\n            assert model_output.shape == target.shape == x_start.shape\n            terms['mse'] = mean_flat((target - model_output) ** 2)\n            if 'vb' in terms:\n                terms['loss'] = terms['mse'] + terms['vb']\n            else:\n                terms['loss'] = terms['mse']\n        else:\n            raise NotImplementedError(self.loss_type)\n\n        return terms\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational 
lower-bound, measured in\n        bits-per-dim.\n        This term can't be optimized, as it only depends on the encoder.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(\n            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0\n        )\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):\n        \"\"\"\n        Compute the entire variational lower-bound, measured in bits-per-dim,\n        as well as other related quantities.\n        :param model: the model to evaluate loss on.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :param clip_denoised: if True, clip denoised samples.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. 
This can be used for conditioning.\n        :return: a dict containing the following keys:\n                 - total_bpd: the total variational lower-bound, per batch element.\n                 - prior_bpd: the prior term in the lower-bound.\n                 - vb: an [N x T] tensor of terms in the lower-bound.\n                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.\n                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.\n        \"\"\"\n        device = x_start.device\n        batch_size = x_start.shape[0]\n\n        vb = []\n        xstart_mse = []\n        mse = []\n        for t in list(range(self.num_timesteps))[::-1]:\n            t_batch = torch.tensor([t] * batch_size, device=device)\n            noise = torch.randn_like(x_start)\n            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)\n            # Calculate VLB term at the current timestep\n            with torch.no_grad():\n                out = self._vb_terms_bpd(\n                    model,\n                    x_start=x_start,\n                    x_t=x_t,\n                    t=t_batch,\n                    clip_denoised=clip_denoised,\n                    model_kwargs=model_kwargs,\n                )\n            vb.append(out['output'])\n            xstart_mse.append(mean_flat((out['pred_xstart'] - x_start) ** 2))\n            eps = self._predict_eps_from_xstart(x_t, t_batch, out['pred_xstart'])\n            mse.append(mean_flat((eps - noise) ** 2))\n\n        vb = torch.stack(vb, dim=1)\n        xstart_mse = torch.stack(xstart_mse, dim=1)\n        mse = torch.stack(mse, dim=1)\n\n        prior_bpd = self._prior_bpd(x_start)\n        total_bpd = vb.sum(dim=1) + prior_bpd\n        return {\n            'total_bpd': total_bpd,\n            'prior_bpd': prior_bpd,\n            'vb': vb,\n            'xstart_mse': xstart_mse,\n            'mse': mse,\n        }\n\n\ndef _extract_into_tensor(arr, timesteps, broadcast_shape):\n    \"\"\"\n  
  Extract values from a 1-D numpy array for a batch of indices.\n    :param arr: the 1-D numpy array.\n    :param timesteps: a tensor of indices into the array to extract.\n    :param broadcast_shape: a larger shape of K dimensions with the batch\n                            dimension equal to the length of timesteps.\n    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.\n    \"\"\"\n    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()\n    while len(res.shape) < len(broadcast_shape):\n        res = res[..., None]\n    return res.expand(broadcast_shape)\n"
  },
  {
    "path": "deepfloyd_if/model/losses.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nHelpers for various likelihood-based losses. These are ported from the original\nHo et al. diffusion model codebase:\nhttps://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py\n\"\"\"\n\nimport torch\nimport numpy as np\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    Compute the elementwise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))\n    between two Gaussians, in nats.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, 'at least one argument must be a Tensor'\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + torch.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n\n\ndef approx_standard_normal_cdf(x):\n    \"\"\"\n    A fast approximation of the cumulative distribution function of the\n    standard normal.\n    \"\"\"\n    return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))\n\n\ndef discretized_gaussian_log_likelihood(x, *, means, log_scales):\n    \"\"\"\n    Compute the log-likelihood of a Gaussian distribution discretizing to a\n    given image.\n    :param x: the target images. 
It is assumed that these were uint8 values,\n              rescaled to the range [-1, 1].\n    :param means: the Gaussian mean Tensor.\n    :param log_scales: the Gaussian log stddev Tensor.\n    :return: a tensor like x of log probabilities (in nats).\n    \"\"\"\n    assert x.shape == means.shape == log_scales.shape\n    centered_x = x - means\n    inv_stdv = torch.exp(-log_scales)\n    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)\n    cdf_plus = approx_standard_normal_cdf(plus_in)\n    min_in = inv_stdv * (centered_x - 1.0 / 255.0)\n    cdf_min = approx_standard_normal_cdf(min_in)\n    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))\n    log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))\n    cdf_delta = cdf_plus - cdf_min\n    log_probs = torch.where(\n        x < -0.999,\n        log_cdf_plus,\n        torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),\n    )\n    assert log_probs.shape == x.shape\n    return log_probs\n"
  },
  {
    "path": "deepfloyd_if/model/nn.py",
    "content": "# -*- coding: utf-8 -*-\nimport math\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import Tensor\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef gelu(x):\n    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))\n\n\n@torch.jit.script\ndef gelu_jit(x):\n    \"\"\"OpenAI's gelu implementation.\"\"\"\n    return gelu(x)\n\n\nclass GELUJit(torch.nn.Module):\n    def forward(self, input: Tensor) -> Tensor:\n        return gelu_jit(input)\n\n\ndef get_activation(activation):\n    if activation == 'silu':\n        return torch.nn.SiLU()\n    elif activation == 'gelu_jit':\n        return GELUJit()\n    elif activation == 'gelu':\n        return torch.nn.GELU()\n    elif activation == 'none':\n        return torch.nn.Identity()\n    else:\n        raise ValueError(f'unknown activation type {activation}')\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):\n        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)\n\n    def forward(self, x):\n        y = super().forward(x).to(x.dtype)\n        return y\n\n\nclass AttentionPooling(nn.Module):\n\n    def __init__(self, num_heads, embed_dim, dtype=None):\n        super().__init__()\n        self.dtype = dtype\n        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim ** 0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)\n        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)\n        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)\n        self.num_heads = num_heads\n        self.dim_per_head = embed_dim // self.num_heads\n\n    def forward(self, x):\n        bs, length, width = x.size()\n\n        def shape(x):\n            # (bs, 
length, width) --> (bs, length, n_heads, dim_per_head)\n            x = x.view(bs, -1, self.num_heads, self.dim_per_head)\n            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n            x = x.transpose(1, 2)\n            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)\n            x = x.reshape(bs*self.num_heads, -1, self.dim_per_head)\n            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)\n            x = x.transpose(1, 2)\n            return x\n\n        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)\n        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)\n\n        # (bs*n_heads, dim_per_head, class_token_length)\n        q = shape(self.q_proj(class_token))\n        # (bs*n_heads, dim_per_head, length+class_token_length)\n        k = shape(self.k_proj(x))\n        v = shape(self.v_proj(x))\n\n        # (bs*n_heads, class_token_length, length+class_token_length):\n        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))\n        weight = torch.einsum(\n            'bct,bcs->bts', q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n\n        # (bs*n_heads, dim_per_head, class_token_length)\n        a = torch.einsum('bts,bcs->bct', weight, v)\n\n        # (bs*n_heads, dim_per_head, 1) --> (bs, 1, width)\n        a = a.reshape(bs, -1, 1).transpose(1, 2)\n\n        return a[:, 0, :]  # cls_token\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f'unsupported dimensions: {dims}')\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n 
   return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f'unsupported dimensions: {dims}')\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef normalization(channels, dtype=None):\n    \"\"\"\n    Make a standard normalization layer.\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, dtype=None):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if dtype is None:\n        dtype = torch.float32\n    half = dim // 2\n    freqs = torch.exp(\n        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n    ).to(device=timesteps.device, dtype=dtype)\n    args = timesteps[:, None].type(dtype) * freqs[None]\n    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n    if dim % 2:\n        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], 
dim=-1)\n    return embedding\n\n\ndef attention(q, k, v, d_k):\n    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)\n    scores = F.softmax(scores, dim=-1)\n    output = torch.matmul(scores, v)\n    return output\n"
  },
  {
    "path": "deepfloyd_if/model/resample.py",
    "content": "# -*- coding: utf-8 -*-\nfrom abc import ABC, abstractmethod\n\nimport torch\nimport numpy as np\n\n\nclass ScheduleSampler(ABC):\n    \"\"\"\n    A distribution over timesteps in the diffusion process, intended to reduce\n    variance of the objective.\n    By default, samplers perform unbiased importance sampling, in which the\n    objective's mean is unchanged.\n    However, subclasses may override sample() to change how the resampled\n    terms are reweighted, allowing for actual changes in the objective.\n    \"\"\"\n\n    @abstractmethod\n    def weights(self):\n        \"\"\"\n        Get a numpy array of weights, one per diffusion step.\n        The weights needn't be normalized, but must be positive.\n        \"\"\"\n\n    def sample(self, batch_size, device):\n        \"\"\"\n        Importance-sample timesteps for a batch.\n        :param batch_size: the number of timesteps.\n        :param device: the torch device to save to.\n        :return: a tuple (timesteps, weights):\n                 - timesteps: a tensor of timestep indices.\n                 - weights: a tensor of weights to scale the resulting losses.\n        \"\"\"\n        w = self.weights()\n        p = w / np.sum(w)\n        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)\n        indices = torch.from_numpy(indices_np).long().to(device)\n        weights_np = 1 / (len(p) * p[indices_np])\n        weights = torch.from_numpy(weights_np).float().to(device)\n        return indices, weights\n\n\nclass UniformSampler(ScheduleSampler):\n    def __init__(self, num_timesteps):\n        self._weights = np.ones([num_timesteps])\n\n    def weights(self):\n        return self._weights\n\n\nclass StaticSampler(ABC):\n    \"\"\"\n    A sampler that always returns the same timestep (static_step) for every\n    batch element, with unit weights.\n    \"\"\"\n\n    def sample(self, batch_size, device, static_step=100):\n        # np.int was removed in NumPy 1.24; use an explicit integer dtype.\n        indices_np = np.ones(batch_size, dtype=np.int64) * static_step\n        weights_np = np.ones(batch_size, dtype=np.int64)\n        indices = torch.from_numpy(indices_np).long().to(device)\n       
 weights = torch.from_numpy(weights_np).float().to(device)\n        return indices, weights\n"
  },
  {
    "path": "deepfloyd_if/model/respace.py",
    "content": "# -*- coding: utf-8 -*-\nimport torch\nimport numpy as np\n\nfrom . import gaussian_diffusion as gd\n\n\ndef create_gaussian_diffusion(\n    *,\n    steps=1000,\n    learn_sigma=False,\n    sigma_small=False,\n    noise_schedule='linear',\n    use_kl=False,\n    predict_xstart=False,\n    rescale_timesteps=False,\n    rescale_learned_sigmas=False,\n    timestep_respacing='',\n):\n    betas = gd.get_named_beta_schedule(noise_schedule, steps)\n    if use_kl:\n        loss_type = gd.LossType.RESCALED_KL\n    elif rescale_learned_sigmas:\n        loss_type = gd.LossType.RESCALED_MSE\n    else:\n        loss_type = gd.LossType.MSE\n    if not timestep_respacing:\n        timestep_respacing = [steps]\n    return SpacedDiffusion(\n        use_timesteps=space_timesteps(steps, timestep_respacing),\n        betas=betas,\n        model_mean_type=(\n            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X\n        ),\n        model_var_type=(\n            (\n                gd.ModelVarType.FIXED_LARGE\n                if not sigma_small\n                else gd.ModelVarType.FIXED_SMALL\n            )\n            if not learn_sigma\n            else gd.ModelVarType.LEARNED_RANGE\n        ),\n        loss_type=loss_type,\n        rescale_timesteps=rescale_timesteps,\n    )\n\n\ndef space_timesteps(num_timesteps, section_counts):\n    \"\"\"\n    Create a set of timesteps to use from an original diffusion process,\n    given the number of timesteps we want to take from equally-sized portions\n    of the original process.\n    For example, if there are 300 timesteps and the section counts are [10,15,20]\n    then the first 100 timesteps are strided to be 10 timesteps, the second 100\n    are strided to be 15 timesteps, and the final 100 are strided to be 20.\n    If the stride is a string starting with \"ddim\", then the fixed striding\n    from the DDIM paper is used, and only one section is allowed.\n    :param num_timesteps: the 
number of diffusion steps in the original\n                          process to divide up.\n    :param section_counts: either a list of numbers, or a string containing\n                           comma-separated numbers, indicating the step count\n                           per section. As a special case, use \"ddimN\" where N\n                           is a number of steps to use the striding from the\n                           DDIM paper, or \"fast27\" for the preset schedule\n                           '10,10,3,2,2' with its final step shifted earlier.\n    :return: a set of diffusion steps from the original process to use.\n    \"\"\"\n    if isinstance(section_counts, str):\n        if section_counts.startswith('ddim'):\n            desired_count = int(section_counts[len('ddim'):])\n            for i in range(1, num_timesteps):\n                if len(range(0, num_timesteps, i)) == desired_count:\n                    return set(range(0, num_timesteps, i))\n            raise ValueError(\n                f'cannot create exactly {desired_count} steps with an integer stride'\n            )\n        elif section_counts == 'fast27':\n            steps = space_timesteps(num_timesteps, '10,10,3,2,2')\n            # Help reduce DDIM artifacts from noisiest timesteps.\n            steps.remove(num_timesteps - 1)\n            steps.add(num_timesteps - 3)\n            return steps\n        section_counts = [int(x) for x in section_counts.split(',')]\n    size_per = num_timesteps // len(section_counts)\n    extra = num_timesteps % len(section_counts)\n    start_idx = 0\n    all_steps = []\n    for i, section_count in enumerate(section_counts):\n        size = size_per + (1 if i < extra else 0)\n        if size < section_count:\n            raise ValueError(\n                f'cannot divide section of {size} steps into {section_count}'\n            )\n        if section_count <= 1:\n            frac_stride = 1\n        else:\n            frac_stride = (size - 1) / (section_count - 1)\n        cur_idx = 0.0\n        taken_steps = []\n        for _ in range(section_count):\n            
taken_steps.append(start_idx + round(cur_idx))\n            cur_idx += frac_stride\n        all_steps += taken_steps\n        start_idx += size\n    return set(all_steps)\n\n\nclass SpacedDiffusion(gd.GaussianDiffusion):\n    \"\"\"\n    A diffusion process which can skip steps in a base diffusion process.\n    :param use_timesteps: a collection (sequence or set) of timesteps from the\n                          original diffusion process to retain.\n    :param kwargs: the kwargs to create the base diffusion process.\n    \"\"\"\n\n    def __init__(self, use_timesteps, **kwargs):\n        self.use_timesteps = set(use_timesteps)\n        self.timestep_map = []\n        self.original_num_steps = len(kwargs['betas'])\n\n        base_diffusion = gd.GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa\n        last_alpha_cumprod = 1.0\n        new_betas = []\n        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):\n            if i in self.use_timesteps:\n                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)\n                last_alpha_cumprod = alpha_cumprod\n                self.timestep_map.append(i)\n        kwargs['betas'] = np.array(new_betas)\n        super().__init__(**kwargs)\n\n    def p_mean_variance(\n        self, model, *args, **kwargs\n    ):  # pylint: disable=signature-differs\n        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)\n\n    def training_losses(\n        self, model, *args, **kwargs\n    ):  # pylint: disable=signature-differs\n        return super().training_losses(self._wrap_model(model), *args, **kwargs)\n\n    def _wrap_model(self, model):\n        if isinstance(model, _WrappedModel):\n            return model\n        return _WrappedModel(\n            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps\n        )\n\n    def _scale_timesteps(self, t):\n        # Scaling is done by the wrapped model.\n        return t\n\n\nclass _WrappedModel:\n  
  def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):\n        self.model = model\n        self.timestep_map = timestep_map\n        self.rescale_timesteps = rescale_timesteps\n        self.original_num_steps = original_num_steps\n\n    def __call__(self, x, ts, **kwargs):\n        map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)\n        new_ts = map_tensor[ts]\n        if self.rescale_timesteps:\n            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)\n        return self.model(x, new_ts, **kwargs)\n"
  },
  {
    "path": "deepfloyd_if/model/unet.py",
    "content": "# -*- coding: utf-8 -*-\nimport os\nimport math\nfrom abc import abstractmethod\n\nimport torch\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \\\n    AttentionPooling\n\n_FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0))\nprint('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION')\nif _FORCE_MEM_EFFICIENT_ATTN:\n    from xformers.ops import memory_efficient_attention  # noqa\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input, and encoder outputs to AttentionBlock children.\n    \"\"\"\n\n    def forward(self, x, emb, encoder_out=None):\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            elif isinstance(layer, AttentionBlock):\n                x = layer(x, encoder_out)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 upsampling occurs in the inner two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        self.dtype = dtype\n        if use_conv:\n            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')\n        else:\n            if self.dtype == torch.bfloat16:\n                x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16)\n            x = F.interpolate(x, scale_factor=2, mode='nearest')\n            if self.dtype == torch.bfloat16:\n                x = x.type(torch.bfloat16)\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 downsampling occurs in the inner two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        self.dtype = dtype\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype)\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param activation: the name of the activation passed to get_activation().\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param use_scale_shift_norm: if True, apply the timestep embedding as a\n        scale and shift after the output normalization instead of adding it.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use this block for downsampling.\n    :param efficient_activation: if True, skip the activation applied to the\n        timestep embedding before its linear projection.\n    :param scale_skip_connection: if True, scale the residual sum by 1/sqrt(2).\n    \"\"\"\n\n    def __init__(\n            self,\n            channels,\n            emb_channels,\n            dropout,\n            activation,\n            out_channels=None,\n            use_conv=False,\n            use_scale_shift_norm=False,\n            dims=2,\n            up=False,\n            down=False,\n            dtype=None,\n            efficient_activation=False,\n            scale_skip_connection=False,\n    ):\n        super().__init__()\n        
self.dtype = dtype\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_scale_shift_norm = use_scale_shift_norm\n        self.efficient_activation = efficient_activation\n        self.scale_skip_connection = scale_skip_connection\n\n        self.in_layers = nn.Sequential(\n            normalization(channels, dtype=self.dtype),\n            get_activation(activation),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims, dtype=self.dtype)\n            self.x_upd = Upsample(channels, False, dims, dtype=self.dtype)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims, dtype=self.dtype)\n            self.x_upd = Downsample(channels, False, dims, dtype=self.dtype)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.emb_layers = nn.Sequential(\n            nn.Identity() if self.efficient_activation else get_activation(activation),\n            linear(\n                emb_channels,\n                2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n                dtype=self.dtype\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels, dtype=self.dtype),\n            get_activation(activation),\n            nn.Dropout(p=dropout),\n            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)),\n        )\n\n        if self.out_channels == channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype)\n        else:\n            self.skip_connection 
= conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype)\n\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = torch.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n\n        res = self.skip_connection(x) + h\n        if self.scale_skip_connection:\n            res *= 0.7071  # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf\n        return res\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(\n            self,\n            channels,\n            num_heads=1,\n            num_head_channels=-1,\n            disable_self_attention=False,\n            encoder_channels=None,\n            dtype=None,\n    ):\n        super().__init__()\n        self.dtype = dtype\n        self.channels = channels\n        
self.disable_self_attention = disable_self_attention\n        if num_head_channels == -1:\n            self.num_heads = num_heads\n        else:\n            assert (\n                channels % num_head_channels == 0\n            ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'\n            self.num_heads = channels // num_head_channels\n        self.norm = normalization(channels, dtype=self.dtype)\n        if self.disable_self_attention:\n            # Only queries are projected; keys and values come from encoder_out.\n            self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype)\n        else:\n            self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)\n        self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention)\n\n        if encoder_channels is not None:\n            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype)\n            self.norm_encoder = normalization(encoder_channels, dtype=self.dtype)\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype))\n\n    def forward(self, x, encoder_out=None):\n        b, c, *spatial = x.shape\n        qkv = self.qkv(self.norm(x).view(b, c, -1))\n        if encoder_out is not None:\n            # Normalize encoder outputs before projecting them to keys/values,\n            # following the Imagen paper: https://arxiv.org/abs/2205.11487\n            encoder_out = self.norm_encoder(encoder_out)\n            encoder_out = self.encoder_kv(encoder_out)\n            h = self.attention(qkv, encoder_out)\n        else:\n            h = self.attention(qkv)\n        h = self.proj_out(h)\n        return x + h.reshape(b, c, *spatial)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention. 
 Matches legacy QKVAttention + input/output heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads, disable_self_attention=False):\n        super().__init__()\n        self.n_heads = n_heads\n        self.disable_self_attention = disable_self_attention\n\n    def forward(self, qkv, encoder_kv=None):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs\n            (an [N x (H * C) x T] tensor of Qs only when self-attention is disabled).\n        :param encoder_kv: an optional [N x (H * 2 * C) x S] tensor of encoder Ks and Vs;\n            they are prepended to the self-attention Ks and Vs, or used alone when\n            self-attention is disabled.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        if self.disable_self_attention:\n            assert encoder_kv is not None, 'encoder_kv is required when self-attention is disabled'\n            ch = width // (1 * self.n_heads)\n            q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1)\n        else:\n            assert width % (3 * self.n_heads) == 0\n            ch = width // (3 * self.n_heads)\n            q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)\n        if encoder_kv is not None:\n            assert encoder_kv.shape[1] == self.n_heads * ch * 2\n            if self.disable_self_attention:\n                k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)\n            else:\n                ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)\n                k = torch.cat([ek, k], dim=-1)\n                v = torch.cat([ev, v], dim=-1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        if _FORCE_MEM_EFFICIENT_ATTN:\n            q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))\n            a = memory_efficient_attention(q, k, v)\n            a = a.permute(0, 2, 1)\n        else:\n            weight = torch.einsum(\n                'bct,bcs->bts', q * scale, k * scale\n            )  # More stable with f16 than dividing afterwards\n            weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n            a = torch.einsum('bts,bcs->bct', weight, v)\n        return a.reshape(bs, -1, length)\n\n\nclass UNetModel(nn.Module):\n    
\"\"\"\n    The full UNet model with attention and timestep embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample; an int, or\n        a sequence with one value per level of `channel_mult`.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used. A comma-separated string of feature-map resolutions is\n        also accepted and converted to downsample rates as image_size // res.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_head_channels: if specified, ignore num_heads and instead use\n                              a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling.
Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    \"\"\"\n\n    def __init__(\n            self,\n            in_channels,\n            model_channels,\n            out_channels,\n            num_res_blocks,\n            attention_resolutions,\n            activation,\n            encoder_dim,\n            att_pool_heads,\n            encoder_channels,\n            image_size,\n            disable_self_attentions=None,\n            dropout=0,\n            channel_mult=(1, 2, 4, 8),\n            conv_resample=True,\n            dims=2,\n            num_classes=None,\n            precision='32',\n            num_heads=1,\n            num_head_channels=-1,\n            num_heads_upsample=-1,\n            use_scale_shift_norm=False,\n            resblock_updown=False,\n            efficient_activation=False,\n            scale_skip_connection=False,\n    ):\n        super().__init__()\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        self.encoder_channels = encoder_channels\n        self.encoder_dim = encoder_dim\n        self.efficient_activation = efficient_activation\n        self.scale_skip_connection = scale_skip_connection\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.dropout = dropout\n\n        # adapt attention resolutions\n        if isinstance(attention_resolutions, str):\n            self.attention_resolutions = []\n            for res in attention_resolutions.split(','):\n                self.attention_resolutions.append(image_size // int(res))\n        else:\n            self.attention_resolutions = attention_resolutions\n        self.attention_resolutions = tuple(self.attention_resolutions)\n        #\n\n        # adapt disable self attention resolutions\n        if not disable_self_attentions:\n            
self.disable_self_attentions = []\n        elif disable_self_attentions is True:\n            # disable self-attention at every attention resolution\n            self.disable_self_attentions = self.attention_resolutions\n        elif isinstance(disable_self_attentions, str):\n            self.disable_self_attentions = []\n            for res in disable_self_attentions.split(','):\n                self.disable_self_attentions.append(image_size // int(res))\n        else:\n            self.disable_self_attentions = disable_self_attentions\n        self.disable_self_attentions = tuple(self.disable_self_attentions)\n        #\n\n        # adapt channel mult\n        if isinstance(channel_mult, str):\n            self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(','))\n        else:\n            self.channel_mult = tuple(channel_mult)\n        #\n\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.dtype = torch.float32\n\n        self.precision = str(precision)\n        self.use_fp16 = self.precision == '16'\n        if self.precision == '16':\n            self.dtype = torch.float16\n        elif self.precision == 'bf16':\n            self.dtype = torch.bfloat16\n\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        self.time_embed_dim = model_channels * max(self.channel_mult)\n        self.time_embed = nn.Sequential(\n            linear(model_channels, self.time_embed_dim, dtype=self.dtype),\n            get_activation(activation),\n            linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),\n        )\n\n        if self.num_classes is not None:\n            self.label_emb = nn.Embedding(num_classes, self.time_embed_dim)\n\n        ch = input_ch = int(self.channel_mult[0] * model_channels)\n        self.input_blocks = nn.ModuleList(\n            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))]\n        )\n        
self._feature_size = ch\n        input_block_chans = [ch]\n        ds = 1\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = [num_res_blocks]*len(self.channel_mult)\n        self.num_res_blocks = num_res_blocks\n\n        for level, mult in enumerate(self.channel_mult):\n            for _ in range(num_res_blocks[level]):\n                layers = [\n                    ResBlock(\n                        ch,\n                        self.time_embed_dim,\n                        dropout,\n                        out_channels=int(mult * model_channels),\n                        dims=dims,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                        dtype=self.dtype,\n                        activation=activation,\n                        efficient_activation=self.efficient_activation,\n                        scale_skip_connection=self.scale_skip_connection,\n                    )\n                ]\n                ch = int(mult * model_channels)\n                if ds in self.attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            num_heads=num_heads,\n                            num_head_channels=num_head_channels,\n                            encoder_channels=encoder_channels,\n                            dtype=self.dtype,\n                            disable_self_attention=ds in self.disable_self_attentions,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(self.channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            
self.time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                            dtype=self.dtype,\n                            activation=activation,\n                            efficient_activation=self.efficient_activation,\n                            scale_skip_connection=self.scale_skip_connection,\n                        )\n                        if resblock_updown\n                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                self.time_embed_dim,\n                dropout,\n                dims=dims,\n                use_scale_shift_norm=use_scale_shift_norm,\n                dtype=self.dtype,\n                activation=activation,\n                efficient_activation=self.efficient_activation,\n                scale_skip_connection=self.scale_skip_connection,\n            ),\n            AttentionBlock(\n                ch,\n                num_heads=num_heads,\n                num_head_channels=num_head_channels,\n                encoder_channels=encoder_channels,\n                dtype=self.dtype,\n                disable_self_attention=ds in self.disable_self_attentions,\n            ),\n            ResBlock(\n                ch,\n                self.time_embed_dim,\n                dropout,\n                dims=dims,\n                use_scale_shift_norm=use_scale_shift_norm,\n                dtype=self.dtype,\n                activation=activation,\n                efficient_activation=self.efficient_activation,\n    
            scale_skip_connection=self.scale_skip_connection,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(self.channel_mult))[::-1]:\n            for i in range(num_res_blocks[level] + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        self.time_embed_dim,\n                        dropout,\n                        out_channels=int(model_channels * mult),\n                        dims=dims,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                        dtype=self.dtype,\n                        activation=activation,\n                        efficient_activation=self.efficient_activation,\n                        scale_skip_connection=self.scale_skip_connection,\n                    )\n                ]\n                ch = int(model_channels * mult)\n                if ds in self.attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            num_heads=num_heads_upsample,\n                            num_head_channels=num_head_channels,\n                            encoder_channels=encoder_channels,\n                            dtype=self.dtype,\n                            disable_self_attention=ds in self.disable_self_attentions,\n                        )\n                    )\n                if level and i == num_res_blocks[level]:\n                    out_ch = ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            self.time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            
use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                            dtype=self.dtype,\n                            activation=activation,\n                            efficient_activation=self.efficient_activation,\n                            scale_skip_connection=self.scale_skip_connection,\n                        )\n                        if resblock_updown\n                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch, dtype=self.dtype),\n            get_activation(activation),\n            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)),\n        )\n\n        self.activation_layer = get_activation(activation) if self.efficient_activation else nn.Identity()\n\n        self.encoder_pooling = nn.Sequential(\n            nn.LayerNorm(encoder_dim, dtype=self.dtype),\n            AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype),\n            nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype),\n            nn.LayerNorm(self.time_embed_dim, dtype=self.dtype)\n        )\n\n        if encoder_dim != encoder_channels:\n            self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype)\n        else:\n            self.encoder_proj = nn.Identity()\n\n        self.cache = None\n\n    def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):\n        hs = []\n        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, dtype=self.dtype))\n\n        if use_cache and self.cache is not None:\n            encoder_out, encoder_pool = self.cache\n        else:\n            text_emb = text_emb.type(self.dtype)\n            
encoder_out = self.encoder_proj(text_emb)\n            encoder_out = encoder_out.permute(0, 2, 1)  # NLC -> NCL\n            if timestep_text_emb is None:\n                timestep_text_emb = text_emb\n            encoder_pool = self.encoder_pooling(timestep_text_emb)\n            if use_cache:\n                self.cache = (encoder_out, encoder_pool)\n\n        emb = emb + encoder_pool.to(emb)\n\n        if aug_emb is not None:\n            emb = emb + aug_emb.to(emb)\n\n        emb = self.activation_layer(emb)\n\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb, encoder_out)\n            hs.append(h)\n        h = self.middle_block(h, emb, encoder_out)\n        for module in self.output_blocks:\n            h = torch.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, encoder_out)\n        h = h.type(self.dtype)\n        h = self.out(h)\n        return h\n\n\nclass SuperResUNetModel(UNetModel):\n    \"\"\"\n    A text2im model that performs super-resolution.\n    Expects an extra kwarg `low_res` to condition on a low-resolution image.\n    \"\"\"\n\n    def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs):\n        self.low_res_diffusion = low_res_diffusion\n        self.interpolate_mode = interpolate_mode\n        super().__init__(*args, **kwargs)\n\n        self.aug_proj = nn.Sequential(\n            linear(self.model_channels, self.time_embed_dim, dtype=self.dtype),\n            get_activation(kwargs['activation']),\n            linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),\n        )\n\n    def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):\n        bs, _, new_height, new_width = x.shape\n\n        align_corners = True\n        if self.interpolate_mode == 'nearest':\n            align_corners = None\n\n        upsampled = F.interpolate(\n            low_res, (new_height, new_width), mode=self.interpolate_mode, 
align_corners=align_corners\n        )\n\n        if aug_level is None:\n            aug_steps = (np.random.random(bs)*1000).astype(np.int64)  # uniform [0, 1) scaled to integer steps in [0, 1000)\n            aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long)\n        else:\n            aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long)\n\n        upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps)\n        x = torch.cat([x, upsampled], dim=1)\n\n        aug_emb = self.aug_proj(\n            timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype)\n        )\n        return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs)\n"
  },
  {
    "path": "deepfloyd_if/modules/__init__.py",
    "content": "# -*- coding: utf-8 -*-\nfrom .stage_I import IFStageI\nfrom .stage_II import IFStageII\nfrom .stage_III import IFStageIII\nfrom .stage_III_sd_x4 import StableStageIII\nfrom .t5 import T5Embedder\nfrom .base import IFBaseModule\n\n__all__ = ['IFBaseModule', 'IFStageI', 'IFStageII', 'IFStageIII', 'StableStageIII', 'T5Embedder']\n"
  },
  {
    "path": "deepfloyd_if/modules/base.py",
    "content": "# -*- coding: utf-8 -*-\nimport os\nimport random\nimport platform\nfrom datetime import datetime\n\nimport torch\nimport torchvision\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport torchvision.transforms as T\nfrom PIL import Image\nfrom omegaconf import OmegaConf\nfrom huggingface_hub import hf_hub_download\nfrom accelerate.utils import set_module_tensor_to_device\n\n\nfrom .. import utils\nfrom ..model.respace import create_gaussian_diffusion\nfrom .utils import load_model_weights, predict_proba, clip_process_generations\n\n\nclass IFBaseModule:\n\n    stage = '-'\n\n    available_models = []\n    cpu_zero_emb = np.load(os.path.join(utils.RESOURCES_ROOT, 'zero_t5-v1_1-xxl_vector.npy'))\n    cpu_zero_emb = torch.from_numpy(cpu_zero_emb)\n\n    respacing_modes = {\n        'fast27': '10,10,3,2,2',\n        'smart27': '7,4,2,1,2,4,7',\n        'smart50': '10,6,4,3,2,2,3,4,6,10',\n        'smart100': '1,1,1,1,2,2,2,2,2,2,3,3,4,4,5,5,6,7,7,8,9,10,13',\n        'smart185': '1,1,2,2,2,3,3,3,4,5,6,7,8,9,10,11,12,13,14,15,16,18,20',\n        'super27': '1,1,1,1,1,1,1,2,5,13',  # for III super-res\n        'super40': '2,2,2,2,2,2,3,4,6,15',  # for III super-res\n        'super100': '4,4,6,6,8,8,10,10,14,30',  # for III super-res\n    }\n\n    wm_pil_img = Image.open(os.path.join(utils.RESOURCES_ROOT, 'wm.png'))\n\n    try:\n        import clip  # noqa\n    except ModuleNotFoundError:\n        print('Warning! 
You should install CLIP: \"pip install git+https://github.com/openai/CLIP.git --no-deps\"')\n        raise\n\n    clip_model, clip_preprocess = clip.load('ViT-L/14', device='cpu')\n    clip_model.eval()\n\n    cpu_w_weights, cpu_w_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'w_head_v1.npz'))\n    cpu_p_weights, cpu_p_biases = load_model_weights(os.path.join(utils.RESOURCES_ROOT, 'p_head_v1.npz'))\n    w_threshold, p_threshold = 0.5, 0.5\n\n    def __init__(self, dir_or_name, device, pil_img_size=256, cache_dir=None, hf_token=None):\n        self.hf_token = hf_token\n        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')\n        self.dir_or_name = dir_or_name\n        self.conf = self.load_conf(dir_or_name) if not self.use_diffusers else None\n        self.device = torch.device(device)\n        self.zero_emb = self.cpu_zero_emb.clone().to(self.device)\n        self.pil_img_size = pil_img_size\n\n    @property\n    def use_diffusers(self):\n        return False\n\n    def embeddings_to_image(\n        self, t5_embs, low_res=None, *,\n        style_t5_embs=None,\n        positive_t5_embs=None,\n        negative_t5_embs=None,\n        batch_repeat=1,\n        dynamic_thresholding_p=0.95,\n        sample_loop='ddpm',\n        sample_timestep_respacing='smart185',\n        dynamic_thresholding_c=1.5,\n        guidance_scale=7.0,\n        aug_level=0.25,\n        positive_mixer=0.15,\n        blur_sigma=None,\n        img_size=None,\n        img_scale=4.0,\n        aspect_ratio='1:1',\n        progress=True,\n        seed=None,\n        sample_fn=None,\n        support_noise=None,\n        support_noise_less_qsample_steps=0,\n        inpainting_mask=None,\n        **kwargs,\n    ):\n        self._clear_cache()\n        image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale)\n        diffusion = self.get_diffusion(sample_timestep_respacing)\n\n        bs_scale = 2 if positive_t5_embs is None else 3\n\n      
  def model_fn(x_t, ts, **kwargs):\n            half = x_t[: len(x_t) // bs_scale]\n            combined = torch.cat([half]*bs_scale, dim=0)\n            model_out = self.model(combined, ts, **kwargs)\n            eps, rest = model_out[:, :3], model_out[:, 3:]\n            if bs_scale == 3:\n                cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)\n                half_eps = uncond_eps + guidance_scale * (\n                    cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps)\n                pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps)\n                eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0)\n            else:\n                cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)\n                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n                eps = torch.cat([half_eps, half_eps], dim=0)\n            return torch.cat([eps, rest], dim=1)\n\n        seed = self.seed_everything(seed)\n\n        text_emb = t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)\n        batch_size = text_emb.shape[0] * batch_repeat\n\n        if positive_t5_embs is not None:\n            positive_t5_embs = positive_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)\n\n        if negative_t5_embs is not None:\n            negative_t5_embs = negative_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1)\n\n        timestep_text_emb = None\n        if style_t5_embs is not None:\n            list_timestep_text_emb = [\n                style_t5_embs.to(self.device, dtype=self.model.dtype).repeat(batch_repeat, 1, 1),\n            ]\n            if positive_t5_embs is not None:\n                list_timestep_text_emb.append(positive_t5_embs)\n            if negative_t5_embs is not None:\n                list_timestep_text_emb.append(negative_t5_embs)\n            
else:\n                list_timestep_text_emb.append(\n                    self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))\n            timestep_text_emb = torch.cat(list_timestep_text_emb, dim=0).to(self.device, dtype=self.model.dtype)\n\n        metadata = {\n            'seed': seed,\n            'guidance_scale': guidance_scale,\n            'dynamic_thresholding_p': dynamic_thresholding_p,\n            'dynamic_thresholding_c': dynamic_thresholding_c,\n            'batch_size': batch_size,\n            'device_name': self.device_name(),\n            'img_size': [image_w, image_h],\n            'sample_loop': sample_loop,\n            'sample_timestep_respacing': sample_timestep_respacing,\n            'stage': self.stage,\n        }\n\n        list_text_emb = [t5_embs.to(self.device)]\n        if positive_t5_embs is not None:\n            list_text_emb.append(positive_t5_embs.to(self.device))\n        if negative_t5_embs is not None:\n            list_text_emb.append(negative_t5_embs.to(self.device))\n        else:\n            list_text_emb.append(\n                self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))\n\n        model_kwargs = dict(\n            text_emb=torch.cat(list_text_emb, dim=0).to(self.device, dtype=self.model.dtype),\n            timestep_text_emb=timestep_text_emb,\n            use_cache=True,\n        )\n        if low_res is not None:\n            if blur_sigma is not None:\n                low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res)\n            model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device)\n            model_kwargs['aug_level'] = aug_level\n\n        if support_noise is None:\n            noise = torch.randn(\n                (batch_size * bs_scale, 3, image_h, image_w), device=self.device, dtype=self.model.dtype)\n        else:\n            assert support_noise_less_qsample_steps < 
len(diffusion.timestep_map) - 1\n            assert support_noise.shape == (1, 3, image_h, image_w)\n            q_sample_steps = torch.tensor([int(len(diffusion.timestep_map) - 1 - support_noise_less_qsample_steps)])\n            support_noise = support_noise.cpu()\n            noise = support_noise.clone()\n            noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...] = diffusion.q_sample(\n                support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...],\n                q_sample_steps,\n            )\n            noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype)\n\n        if inpainting_mask is not None:\n            inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long)\n\n        if sample_loop == 'ddpm':\n            with torch.no_grad():\n                sample = diffusion.p_sample_loop(\n                    model_fn,\n                    (batch_size * bs_scale, 3, image_h, image_w),\n                    noise=noise,\n                    clip_denoised=True,\n                    model_kwargs=model_kwargs,\n                    dynamic_thresholding_p=dynamic_thresholding_p,\n                    dynamic_thresholding_c=dynamic_thresholding_c,\n                    inpainting_mask=inpainting_mask,\n                    device=self.device,\n                    progress=progress,\n                    sample_fn=sample_fn,\n                )[:batch_size]\n        elif sample_loop == 'ddim':\n            with torch.no_grad():\n                sample = diffusion.ddim_sample_loop(\n                    model_fn,\n                    (batch_size * bs_scale, 3, image_h, image_w),\n                    noise=noise,\n                    clip_denoised=True,\n                    model_kwargs=model_kwargs,\n                    dynamic_thresholding_p=dynamic_thresholding_p,\n                    dynamic_thresholding_c=dynamic_thresholding_c,\n              
      device=self.device,\n                    progress=progress,\n                    sample_fn=sample_fn,\n                )[:batch_size]\n        else:\n            raise ValueError(f'Sample loop \"{sample_loop}\" is not supported')\n\n        sample = self.__validate_generations(sample)\n        self._clear_cache()\n\n        return sample, metadata\n\n    def load_conf(self, dir_or_name, filename='config.yml'):\n        path = self._get_path_or_download_file_from_hf(dir_or_name, filename)\n        conf = OmegaConf.load(path)\n        return conf\n\n    def load_checkpoint(self, model, dir_or_name, filename='pytorch_model.bin'):\n        path = self._get_path_or_download_file_from_hf(dir_or_name, filename)\n        if os.path.exists(path):\n            checkpoint = torch.load(path, map_location='cpu')\n            param_device = 'cpu'\n            for param_name, param in checkpoint.items():\n                set_module_tensor_to_device(model, param_name, param_device, value=param)\n        else:\n            print(f'Warning!
 In directory \"{dir_or_name}\" filename \"{filename}\" is not found.')\n        return model\n\n    def _get_path_or_download_file_from_hf(self, dir_or_name, filename):\n        if dir_or_name in self.available_models:\n            cache_dir = os.path.join(self.cache_dir, dir_or_name)\n            hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,\n                            force_filename=filename, token=self.hf_token)\n            return os.path.join(cache_dir, filename)\n        else:\n            return os.path.join(dir_or_name, filename)\n\n    def get_diffusion(self, timestep_respacing):\n        timestep_respacing = self.respacing_modes.get(timestep_respacing, timestep_respacing)\n        diffusion = create_gaussian_diffusion(\n            steps=1000,\n            learn_sigma=True,\n            sigma_small=False,\n            noise_schedule='cosine',\n            use_kl=False,\n            predict_xstart=False,\n            rescale_timesteps=True,\n            rescale_learned_sigmas=True,\n            timestep_respacing=timestep_respacing,\n        )\n        return diffusion\n\n    @staticmethod\n    def seed_everything(seed=None):\n        if seed is None:\n            seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))\n        random.seed(seed)\n        os.environ['PYTHONHASHSEED'] = str(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = True\n        return seed\n\n    def device_name(self):\n        if self.device.type == 'cpu':\n            return 'cpu_' + str(platform.processor())\n        if self.device.type == 'cuda':\n            return torch.cuda.get_device_name(self.device)\n        return '-'\n\n    def to_images(self, generations, disable_watermark=False):\n        bs, c, h, w = generations.shape\n        coef = min(h / 
self.pil_img_size, w / self.pil_img_size)\n        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)\n\n        S1, S2 = 1024 ** 2, img_w * img_h\n        K = (S2 / S1) ** 0.5\n        wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)\n\n        wm_img = self.wm_pil_img.resize(\n            (wm_size, wm_size), getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None)\n\n        pil_images = []\n        for image in ((generations + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu():\n            pil_img = torchvision.transforms.functional.to_pil_image(image).convert('RGB')\n            pil_img = pil_img.resize((img_w, img_h), getattr(Image, 'Resampling', Image).NEAREST)\n            if not disable_watermark:\n                pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])\n            pil_images.append(pil_img)\n        return pil_images\n\n    def show(self, pil_images, nrow=None, size=10):\n        if nrow is None:\n            nrow = round(len(pil_images)**0.5)\n\n        imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)\n        if not isinstance(imgs, list):\n            imgs = [imgs.cpu()]\n\n        fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size))\n        for i, img in enumerate(imgs):\n            img = img.detach()\n            img = torchvision.transforms.functional.to_pil_image(img)\n            axs[0, i].imshow(np.asarray(img))\n            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n\n        fix.show()\n        plt.show()\n\n    def _clear_cache(self):\n        self.model.cache = None\n\n    def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale):\n        if low_res is not None:\n            bs, c, h, w = low_res.shape\n            image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale//32))*32\n        else:\n            scale_w, 
scale_h = aspect_ratio.split(':')\n            scale_w, scale_h = int(scale_w), int(scale_h)\n            coef = scale_w / scale_h\n            image_h, image_w = img_size, img_size\n            if coef >= 1:\n                image_w = int(round(img_size/8 * coef) * 8)\n            else:\n                image_h = int(round(img_size/8 / coef) * 8)\n\n        assert image_h % 8 == 0\n        assert image_w % 8 == 0\n\n        return image_w, image_h\n\n    def __validate_generations(self, generations):\n        with torch.no_grad():\n            imgs = clip_process_generations(generations)\n            image_features = self.clip_model.encode_image(imgs.to('cpu'))\n            image_features = image_features.detach().cpu().numpy().astype(np.float16)\n            p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)\n            w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)\n            query = p_pred > self.p_threshold\n            if query.sum() > 0:\n                generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])\n            query = w_pred > self.w_threshold\n            if query.sum() > 0:\n                generations[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(generations[query])\n        return generations\n"
  },
  {
    "path": "deepfloyd_if/modules/stage_I.py",
    "content": "# -*- coding: utf-8 -*-\nimport accelerate\n\nfrom .base import IFBaseModule\nfrom ..model import UNetModel\n\n\nclass IFStageI(IFBaseModule):\n    stage = 'I'\n    available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0']\n\n    def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):\n        \"\"\"\n        :param conf_or_path:\n        :param device:\n        :param cache_dir:\n        :param use_auth_token:\n        \"\"\"\n        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)\n        model_params = dict(self.conf.params)\n        model_params.update(model_kwargs or {})\n        with accelerate.init_empty_weights():\n            self.model = UNetModel(**model_params)\n        self.model = self.load_checkpoint(self.model, self.dir_or_name)\n        self.model.eval().to(self.device)\n\n    def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,\n                            batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,\n                            sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,\n                            aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):\n\n        return super().embeddings_to_image(\n            t5_embs=t5_embs,\n            style_t5_embs=style_t5_embs,\n            positive_t5_embs=positive_t5_embs,\n            negative_t5_embs=negative_t5_embs,\n            batch_repeat=batch_repeat,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            sample_loop=sample_loop,\n            sample_timestep_respacing=sample_timestep_respacing,\n            guidance_scale=guidance_scale,\n            img_size=64,\n            aspect_ratio=aspect_ratio,\n            progress=progress,\n            seed=seed,\n            sample_fn=sample_fn,\n            
positive_mixer=positive_mixer,\n            **kwargs\n        )\n"
  },
  {
    "path": "deepfloyd_if/modules/stage_II.py",
    "content": "# -*- coding: utf-8 -*-\nimport accelerate\n\nfrom .base import IFBaseModule\nfrom ..model import SuperResUNetModel\n\n\nclass IFStageII(IFBaseModule):\n    stage = 'II'\n    available_models = ['IF-II-M-v1.0', 'IF-II-L-v1.0']\n\n    def __init__(self, *args, model_kwargs=None, pil_img_size=256, **kwargs):\n        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)\n        model_params = dict(self.conf.params)\n        model_params.update(model_kwargs or {})\n        with accelerate.init_empty_weights():\n            self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params)\n        self.model = self.load_checkpoint(self.model, self.dir_or_name)\n        self.model.eval().to(self.device)\n\n    def embeddings_to_image(\n            self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,\n            aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm',\n            sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5,\n            progress=True, seed=None, sample_fn=None, **kwargs):\n        return super().embeddings_to_image(\n            t5_embs=t5_embs,\n            low_res=low_res,\n            style_t5_embs=style_t5_embs,\n            positive_t5_embs=positive_t5_embs,\n            negative_t5_embs=negative_t5_embs,\n            batch_repeat=batch_repeat,\n            aug_level=aug_level,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            sample_loop=sample_loop,\n            sample_timestep_respacing=sample_timestep_respacing,\n            guidance_scale=guidance_scale,\n            positive_mixer=positive_mixer,\n            img_size=256,\n            img_scale=img_scale,\n            progress=progress,\n            seed=seed,\n            sample_fn=sample_fn,\n            **kwargs\n        
)\n"
  },
  {
    "path": "deepfloyd_if/modules/stage_III.py",
    "content": "# -*- coding: utf-8 -*-\nimport accelerate\n\nfrom .base import IFBaseModule\nfrom ..model import SuperResUNetModel\n\n\nclass IFStageIII(IFBaseModule):\n\n    available_models = ['IF-III-L-v1.0']\n\n    def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):\n        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)\n        model_params = dict(self.conf.params)\n        model_params.update(model_kwargs or {})\n        with accelerate.init_empty_weights():\n            self.model = SuperResUNetModel(low_res_diffusion=self.get_diffusion('1000'), **model_params)\n        self.model = self.load_checkpoint(self.model, self.dir_or_name)\n        self.model.eval().to(self.device)\n\n    def embeddings_to_image(\n            self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,\n            aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,\n            sample_loop='ddpm', sample_timestep_respacing='super40', guidance_scale=4.0, img_scale=4.0,\n            progress=True, seed=None, sample_fn=None, **kwargs):\n        return super().embeddings_to_image(\n            t5_embs=t5_embs,\n            low_res=low_res,\n            style_t5_embs=style_t5_embs,\n            positive_t5_embs=positive_t5_embs,\n            negative_t5_embs=negative_t5_embs,\n            batch_repeat=batch_repeat,\n            aug_level=aug_level,\n            blur_sigma=blur_sigma,\n            dynamic_thresholding_p=dynamic_thresholding_p,\n            dynamic_thresholding_c=dynamic_thresholding_c,\n            sample_loop=sample_loop,\n            sample_timestep_respacing=sample_timestep_respacing,\n            guidance_scale=guidance_scale,\n            positive_mixer=positive_mixer,\n            img_size=1024,\n            img_scale=img_scale,\n            progress=progress,\n            seed=seed,\n            sample_fn=sample_fn,\n        
    **kwargs\n        )\n"
  },
  {
    "path": "deepfloyd_if/modules/stage_III_sd_x4.py",
    "content": "# -*- coding: utf-8 -*-\nimport diffusers\nfrom diffusers import DiffusionPipeline, DDPMScheduler\nimport torch\nimport os\n\nfrom .base import IFBaseModule\nimport packaging.version as pv\n\n\nclass StableStageIII(IFBaseModule):\n\n    available_models = ['stable-diffusion-x4-upscaler']\n\n    def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):\n        super().__init__(*args, pil_img_size=pil_img_size, **kwargs)\n        if pv.parse(diffusers.__version__) <= pv.parse('0.15.1'):\n            raise ValueError(\n                'Make sure to have `diffusers >= 0.16.0` installed.'\n                ' Please run `pip install diffusers --upgrade`'\n            )\n\n        model_id = os.path.join('stabilityai', self.dir_or_name)\n\n        model_kwargs = model_kwargs or {}\n        precision = str(model_kwargs.get('precision', '16'))\n        if precision == '16':\n            torch_dtype = torch.float16\n        elif precision == 'bf16':\n            torch_dtype = torch.bfloat16\n        else:\n            torch_dtype = torch.float32\n\n        self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token)\n        self.model.to(self.device)\n\n        if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')):\n            self.model.enable_xformers_memory_efficient_attention()\n\n    @property\n    def use_diffusers(self):\n        if self.dir_or_name == self.available_models[-1]:\n            return True\n        elif os.path.isdir(self.dir_or_name) and os.path.isfile(os.path.join(self.dir_or_name, 'model_index.json')):\n            return True\n        return False\n\n    def embeddings_to_image(\n            self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,\n            aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,\n            sample_loop='ddpm', 
sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0,\n            progress=True, seed=None, sample_fn=None, **kwargs):\n\n        prompt = kwargs.pop('prompt')\n        noise_level = kwargs.pop('noise_level', 20)\n\n        if sample_loop == 'ddpm':\n            self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config)\n        else:\n            raise ValueError(f\"For now only the 'ddpm' sample loop type is supported, but you passed {sample_loop}\")\n\n        num_inference_steps = int(sample_timestep_respacing)\n\n        self.model.set_progress_bar_config(disable=not progress)\n\n        generator = torch.manual_seed(seed)\n        prompt = sum([batch_repeat * [p] for p in prompt], [])\n        low_res = low_res.repeat(batch_repeat, 1, 1, 1)\n\n        metadata = {\n            'image': low_res,\n            'prompt': prompt,\n            'noise_level': noise_level,\n            'generator': generator,\n            'guidance_scale': guidance_scale,\n            'num_inference_steps': num_inference_steps,\n            'output_type': 'pt',\n        }\n\n        images = self.model(**metadata).images\n\n        sample = self._IFBaseModule__validate_generations(images)\n\n        return sample, metadata\n"
  },
  {
    "path": "deepfloyd_if/modules/t5.py",
    "content": "# -*- coding: utf-8 -*-\nimport os\nimport re\nimport html\nimport urllib.parse as ul\n\nimport ftfy\nimport torch\nfrom bs4 import BeautifulSoup\nfrom transformers import T5EncoderModel, AutoTokenizer\nfrom huggingface_hub import hf_hub_download\n\n\nclass T5Embedder:\n\n    available_models = ['t5-v1_1-xxl']\n    bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\\)'+'\\('+'\\]'+'\\['+'\\}'+'\\{'+'\\|'+'\\\\'+'\\/'+'\\*' + r']{1,}')  # noqa\n\n    def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True,\n                 t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):\n        self.device = torch.device(device)\n        self.torch_dtype = torch_dtype or torch.bfloat16\n        if t5_model_kwargs is None:\n            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}\n            if use_offload_folder is not None:\n                t5_model_kwargs['offload_folder'] = use_offload_folder\n                t5_model_kwargs['device_map'] = {\n                    'shared': self.device,\n                    'encoder.embed_tokens': self.device,\n                    'encoder.block.0': self.device,\n                    'encoder.block.1': self.device,\n                    'encoder.block.2': self.device,\n                    'encoder.block.3': self.device,\n                    'encoder.block.4': self.device,\n                    'encoder.block.5': self.device,\n                    'encoder.block.6': self.device,\n                    'encoder.block.7': self.device,\n                    'encoder.block.8': self.device,\n                    'encoder.block.9': self.device,\n                    'encoder.block.10': self.device,\n                    'encoder.block.11': self.device,\n                    'encoder.block.12': 'disk',\n                    'encoder.block.13': 'disk',\n                    'encoder.block.14': 'disk',\n                    
'encoder.block.15': 'disk',\n                    'encoder.block.16': 'disk',\n                    'encoder.block.17': 'disk',\n                    'encoder.block.18': 'disk',\n                    'encoder.block.19': 'disk',\n                    'encoder.block.20': 'disk',\n                    'encoder.block.21': 'disk',\n                    'encoder.block.22': 'disk',\n                    'encoder.block.23': 'disk',\n                    'encoder.final_layer_norm': 'disk',\n                    'encoder.dropout': 'disk',\n                }\n            else:\n                t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}\n\n        self.use_text_preprocessing = use_text_preprocessing\n        self.hf_token = hf_token\n        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')\n        self.dir_or_name = dir_or_name\n\n        tokenizer_path, path = dir_or_name, dir_or_name\n        if dir_or_name in self.available_models:\n            cache_dir = os.path.join(self.cache_dir, dir_or_name)\n            for filename in [\n                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',\n                'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'\n            ]:\n                hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,\n                                force_filename=filename, token=self.hf_token)\n            tokenizer_path, path = cache_dir, cache_dir\n        else:\n            cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')\n            for filename in [\n                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',\n            ]:\n                hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,\n                                force_filename=filename, token=self.hf_token)\n            
tokenizer_path = cache_dir\n\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n        self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()\n\n    def get_text_embeddings(self, texts):\n        texts = [self.text_preprocessing(text) for text in texts]\n\n        text_tokens_and_mask = self.tokenizer(\n            texts,\n            max_length=77,\n            padding='max_length',\n            truncation=True,\n            return_attention_mask=True,\n            add_special_tokens=True,\n            return_tensors='pt'\n        )\n        text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']\n        text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']\n\n        with torch.no_grad():\n            text_encoder_embs = self.model(\n                input_ids=text_tokens_and_mask['input_ids'].to(self.device),\n                attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),\n            )['last_hidden_state'].detach()\n\n        return text_encoder_embs\n\n    def text_preprocessing(self, text):\n        if self.use_text_preprocessing:\n            # The exact text cleaning as was in the training stage:\n            text = self.clean_caption(text)\n            text = self.clean_caption(text)\n            return text\n        else:\n            return text.lower().strip()\n\n    @staticmethod\n    def basic_clean(text):\n        text = ftfy.fix_text(text)\n        text = html.unescape(html.unescape(text))\n        return text.strip()\n\n    def clean_caption(self, caption):\n        caption = str(caption)\n        caption = ul.unquote_plus(caption)\n        caption = caption.strip().lower()\n        caption = re.sub('<person>', 'person', caption)\n        # urls:\n        caption = re.sub(\n            r'\\b((?:https?:(?:\\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\\w/-]*\\b\\/?(?!@)))',  # noqa\n            '', caption)  # 
regex for urls\n        caption = re.sub(\n            r'\\b((?:www:(?:\\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\\w/-]*\\b\\/?(?!@)))',  # noqa\n            '', caption)  # regex for urls\n        # html:\n        caption = BeautifulSoup(caption, features='html.parser').text\n\n        # @<nickname>\n        caption = re.sub(r'@[\\w\\d]+\\b', '', caption)\n\n        # 31C0—31EF CJK Strokes\n        # 31F0—31FF Katakana Phonetic Extensions\n        # 3200—32FF Enclosed CJK Letters and Months\n        # 3300—33FF CJK Compatibility\n        # 3400—4DBF CJK Unified Ideographs Extension A\n        # 4DC0—4DFF Yijing Hexagram Symbols\n        # 4E00—9FFF CJK Unified Ideographs\n        caption = re.sub(r'[\\u31c0-\\u31ef]+', '', caption)\n        caption = re.sub(r'[\\u31f0-\\u31ff]+', '', caption)\n        caption = re.sub(r'[\\u3200-\\u32ff]+', '', caption)\n        caption = re.sub(r'[\\u3300-\\u33ff]+', '', caption)\n        caption = re.sub(r'[\\u3400-\\u4dbf]+', '', caption)\n        caption = re.sub(r'[\\u4dc0-\\u4dff]+', '', caption)\n        caption = re.sub(r'[\\u4e00-\\u9fff]+', '', caption)\n        #######################################################\n\n        # all types of dash --> \"-\"\n        caption = re.sub(\n            r'[\\u002D\\u058A\\u05BE\\u1400\\u1806\\u2010-\\u2015\\u2E17\\u2E1A\\u2E3A\\u2E3B\\u2E40\\u301C\\u3030\\u30A0\\uFE31\\uFE32\\uFE58\\uFE63\\uFF0D]+',  # noqa\n            '-', caption)\n\n        # normalize all quote variants to a single standard\n        caption = re.sub(r'[`´«»“”¨]', '\"', caption)\n        caption = re.sub(r'[‘’]', \"'\", caption)\n\n        # &quot;\n        caption = re.sub(r'&quot;?', '', caption)\n        # &amp\n        caption = re.sub(r'&amp', '', caption)\n\n        # ip addresses:\n        caption = re.sub(r'\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}', ' ', caption)\n\n        # article ids:\n        caption = re.sub(r'\\d:\\d\\d\\s+$', '', caption)\n\n        # \\n\n        
caption = re.sub(r'\\\\n', ' ', caption)\n\n        # \"#123\"\n        caption = re.sub(r'#\\d{1,3}\\b', '', caption)\n        # \"#12345..\"\n        caption = re.sub(r'#\\d{5,}\\b', '', caption)\n        # \"123456..\"\n        caption = re.sub(r'\\b\\d{6,}\\b', '', caption)\n        # filenames:\n        caption = re.sub(r'[\\S]+\\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)\n\n        #\n        caption = re.sub(r'[\\\"\\']{2,}', r'\"', caption)  # \"\"\"AUSVERKAUFT\"\"\"\n        caption = re.sub(r'[\\.]{2,}', r' ', caption)  # \"\"\"AUSVERKAUFT\"\"\"\n\n        caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT\n        caption = re.sub(r'\\s+\\.\\s+', r' ', caption)  # \" . \"\n\n        # this-is-my-cute-cat / this_is_my_cute_cat\n        regex2 = re.compile(r'(?:\\-|\\_)')\n        if len(re.findall(regex2, caption)) > 3:\n            caption = re.sub(regex2, ' ', caption)\n\n        caption = self.basic_clean(caption)\n\n        caption = re.sub(r'\\b[a-zA-Z]{1,3}\\d{3,15}\\b', '', caption)  # jc6640\n        caption = re.sub(r'\\b[a-zA-Z]+\\d+[a-zA-Z]+\\b', '', caption)  # jc6640vc\n        caption = re.sub(r'\\b\\d+[a-zA-Z]+\\d+\\b', '', caption)  # 6640vc231\n\n        caption = re.sub(r'(worldwide\\s+)?(free\\s+)?shipping', '', caption)\n        caption = re.sub(r'(free\\s)?download(\\sfree)?', '', caption)\n        caption = re.sub(r'\\bclick\\b\\s(?:for|on)\\s\\w+', '', caption)\n        caption = re.sub(r'\\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\\simage[s]?)?', '', caption)\n        caption = re.sub(r'\\bpage\\s+\\d+\\b', '', caption)\n\n        caption = re.sub(r'\\b\\d*[a-zA-Z]+\\d+[a-zA-Z]+\\d+[a-zA-Z\\d]*\\b', r' ', caption)  # j2d1a2a...\n\n        caption = re.sub(r'\\b\\d+\\.?\\d*[xх×]\\d+\\.?\\d*\\b', '', caption)\n\n        caption = re.sub(r'\\b\\s+\\:\\s+', r': ', caption)\n        caption = re.sub(r'(\\D[,\\./])\\b', r'\\1 ', caption)\n        caption = re.sub(r'\\s+', ' ', 
caption)\n\n        caption.strip()\n\n        caption = re.sub(r'^[\\\"\\']([\\w\\W]+)[\\\"\\']$', r'\\1', caption)\n        caption = re.sub(r'^[\\'\\_,\\-\\:;]', r'', caption)\n        caption = re.sub(r'[\\'\\_,\\-\\:\\-\\+]$', r'', caption)\n        caption = re.sub(r'^\\.\\S+$', '', caption)\n\n        return caption.strip()\n"
  },
  {
    "path": "deepfloyd_if/modules/utils.py",
    "content": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport torchvision.transforms as T\n\n\ndef predict_proba(X, weights, biases):\n    logits = X @ weights.T + biases\n    proba = np.where(logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)))\n    return proba.T\n\n\ndef load_model_weights(path):\n    model_weights = np.load(path)\n    return model_weights['weights'], model_weights['biases']\n\n\ndef clip_process_generations(generations):\n    min_size = min(generations.shape[-2:])\n    return T.Compose([\n        T.CenterCrop(min_size),\n        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),\n        T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n    ])(generations)\n"
  },
  {
    "path": "deepfloyd_if/pipelines/__init__.py",
    "content": "# -*- coding: utf-8 -*-\nfrom .dream import dream\nfrom .style_transfer import style_transfer\nfrom .super_resolution import super_resolution\nfrom .inpainting import inpainting\n\n__all__ = ['dream', 'style_transfer', 'super_resolution', 'inpainting']\n"
  },
  {
    "path": "deepfloyd_if/pipelines/dream.py",
    "content": "# -*- coding: utf-8 -*-\nfrom datetime import datetime\n\nimport torch\n\n\ndef dream(\n    t5,\n    if_I,\n    if_II=None,\n    if_III=None,\n    *,\n    prompt,\n    style_prompt=None,\n    negative_prompt=None,\n    seed=None,\n    aspect_ratio='1:1',\n    if_I_kwargs=None,\n    if_II_kwargs=None,\n    if_III_kwargs=None,\n    progress=True,\n    return_tensors=False,\n    disable_watermark=False,\n):\n    \"\"\"\n    Generate images from a text description.\n\n    :param optional dict if_I_kwargs:\n        \"dynamic_thresholding_p\": 0.95, [0.5, 1.0] controls color saturation at high cfg values\n        \"dynamic_thresholding_c\": 1.5, [1.0, 15.0] clips the limiter to avoid greyish images at high limiter values\n        \"guidance_scale\": 7.0, [1.0, 20.0] controls how closely the image follows the text prompt\n        \"positive_mixer\": 0.25, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum\n        \"sample_timestep_respacing\": \"150\", see the available modes in IFBaseModule.respacing_modes or pass a custom value\n\n    :param optional dict if_II_kwargs:\n        \"dynamic_thresholding_p\": 0.95, [0.5, 1.0] controls color saturation at high cfg values\n        \"dynamic_thresholding_c\": 1.0, [1.0, 15.0] clips the limiter to avoid greyish images at high limiter values\n        \"guidance_scale\": 4.0, [1.0, 20.0] controls the amount of texture and detail in the final image\n        \"aug_level\": 0.25, [0.0, 1.0] adds augmentation to generate more realistic images\n        \"positive_mixer\": 0.5, [0.0, 1.0] contribution of the second positive prompt, 0.0 - minimum, 1.0 - maximum\n        \"sample_timestep_respacing\": \"smart50\", see the available modes in IFBaseModule.respacing_modes or pass a custom value\n\n    :param deepfloyd_if.modules.IFStageI if_I: obj\n    :param deepfloyd_if.modules.IFStageII if_II: obj\n    :param deepfloyd_if.modules.IFStageIII if_III: obj\n    :param deepfloyd_if.modules.T5Embedder t5: obj\n\n    :param 
int seed: random seed; if None, a time-based value is used\n    :param str aspect_ratio: output aspect ratio as 'W:H', e.g. '1:1' or '16:9'\n    :param str prompt: text hint/description\n    :param str style_prompt: text hint/description for style\n    :param str negative_prompt: text hint/description for the negative prompt, used as the unconditional embedding\n    :param bool progress: whether to show sampling progress bars\n    :param bool return_tensors: if True, also return the raw generation tensors\n    :param bool disable_watermark: if True, do not paste the watermark onto the output images\n    :return: dict mapping stage names ('I', 'II', 'III') to lists of PIL images; with return_tensors=True, a tuple of\n        that dict and (stageI_generations, stageII_generations, stageIII_generations), where skipped stages are None\n    \"\"\"\n    if seed is None:\n        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))\n    if_I.seed_everything(seed)\n\n    if isinstance(prompt, str):\n        prompt = [prompt]\n\n    t5_embs = t5.get_text_embeddings(prompt)\n\n    if_I_kwargs = if_I_kwargs or {}\n    if_I_kwargs['seed'] = seed\n    if_I_kwargs['t5_embs'] = t5_embs\n    if_I_kwargs['aspect_ratio'] = aspect_ratio\n    if_I_kwargs['progress'] = progress\n\n    if style_prompt is not None:\n        if isinstance(style_prompt, str):\n            style_prompt = [style_prompt]\n        style_t5_embs = t5.get_text_embeddings(style_prompt)\n        if_I_kwargs['style_t5_embs'] = style_t5_embs\n        if_I_kwargs['positive_t5_embs'] = style_t5_embs\n\n    if negative_prompt is not None:\n        if isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n        negative_t5_embs = t5.get_text_embeddings(negative_prompt)\n        if_I_kwargs['negative_t5_embs'] = negative_t5_embs\n\n    stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)\n    pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)\n\n    result = {'I': pil_images_I}\n\n    if if_II is not None:\n        if_II_kwargs = if_II_kwargs or {}\n        if_II_kwargs['low_res'] = stageI_generations\n        if_II_kwargs['seed'] = seed\n        if_II_kwargs['t5_embs'] = t5_embs\n        if_II_kwargs['progress'] = progress\n        if_II_kwargs['style_t5_embs'] = if_I_kwargs.get('style_t5_embs')\n        if_II_kwargs['positive_t5_embs'] = if_I_kwargs.get('positive_t5_embs')\n\n        stageII_generations, _meta = 
if_II.embeddings_to_image(**if_II_kwargs)\n        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)\n\n        result['II'] = pil_images_II\n    else:\n        stageII_generations = None\n\n    if if_II is not None and if_III is not None:\n        if_III_kwargs = if_III_kwargs or {}\n\n        stageIII_generations = []\n        for idx in range(len(stageII_generations)):\n            if if_III.use_diffusers:\n                if_III_kwargs['prompt'] = prompt[idx: idx+1]\n\n            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]\n            if_III_kwargs['seed'] = seed\n            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]\n            if_III_kwargs['progress'] = progress\n            style_t5_embs = if_I_kwargs.get('style_t5_embs')\n            if style_t5_embs is not None:\n                style_t5_embs = style_t5_embs[idx:idx+1]\n            positive_t5_embs = if_I_kwargs.get('positive_t5_embs')\n            if positive_t5_embs is not None:\n                positive_t5_embs = positive_t5_embs[idx:idx+1]\n            if_III_kwargs['style_t5_embs'] = style_t5_embs\n            if_III_kwargs['positive_t5_embs'] = positive_t5_embs\n\n            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)\n            stageIII_generations.append(_stageIII_generations)\n\n        stageIII_generations = torch.cat(stageIII_generations, 0)\n        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)\n\n        result['III'] = pil_images_III\n    else:\n        stageIII_generations = None\n\n    if return_tensors:\n        return result, (stageI_generations, stageII_generations, stageIII_generations)\n    else:\n        return result\n"
  },
  {
    "path": "deepfloyd_if/pipelines/inpainting.py",
    "content": "# -*- coding: utf-8 -*-\nfrom datetime import datetime\n\nimport PIL\nimport torch\n\nfrom .utils import _prepare_pil_image\n\n\ndef inpainting(\n    t5,\n    if_I,\n    if_II=None,\n    if_III=None,\n    *,\n    support_pil_img,\n    prompt,\n    inpainting_mask,\n    negative_prompt=None,\n    seed=None,\n    if_I_kwargs=None,\n    if_II_kwargs=None,\n    if_III_kwargs=None,\n    progress=True,\n    return_tensors=False,\n    disable_watermark=False,\n):\n    from skimage.transform import resize  # noqa\n    from skimage import img_as_bool  # noqa\n    assert isinstance(support_pil_img, PIL.Image.Image)\n\n    if seed is None:\n        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))\n\n    # a bare string would otherwise be tokenized character by character\n    if isinstance(prompt, str):\n        prompt = [prompt]\n\n    t5_embs = t5.get_text_embeddings(prompt)\n\n    if negative_prompt is not None:\n        if isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n        negative_t5_embs = t5.get_text_embeddings(negative_prompt)\n    else:\n        negative_t5_embs = None\n\n    low_res = _prepare_pil_image(support_pil_img, 64)\n    mid_res = _prepare_pil_image(support_pil_img, 256)\n    high_res = _prepare_pil_image(support_pil_img, 1024)\n\n    result = {}\n\n    _, _, image_h, image_w = low_res.shape\n    if_I_kwargs = if_I_kwargs or {}\n    if_I_kwargs['seed'] = seed\n    if_I_kwargs['progress'] = progress\n    if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'\n\n    if_I_kwargs['t5_embs'] = t5_embs\n    if_I_kwargs['negative_t5_embs'] = negative_t5_embs\n\n    if_I_kwargs['support_noise'] = low_res\n\n    inpainting_mask_I = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))\n    inpainting_mask_I = torch.from_numpy(inpainting_mask_I).unsqueeze(0).to(if_I.device)\n\n    if_I_kwargs['inpainting_mask'] = inpainting_mask_I\n\n    stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)\n    pil_images_I = if_I.to_images(stageI_generations, disable_watermark=disable_watermark)\n\n    result['I'] = 
pil_images_I\n\n    if if_II is not None:\n        _, _, image_h, image_w = mid_res.shape\n\n        if_II_kwargs = if_II_kwargs or {}\n        if_II_kwargs['low_res'] = stageI_generations\n        if_II_kwargs['seed'] = seed\n        if_II_kwargs['t5_embs'] = t5_embs\n        if_II_kwargs['negative_t5_embs'] = negative_t5_embs\n        if_II_kwargs['progress'] = progress\n\n        if_II_kwargs['support_noise'] = mid_res\n\n        if 'inpainting_mask' not in if_II_kwargs:\n            inpainting_mask_II = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))\n            inpainting_mask_II = torch.from_numpy(inpainting_mask_II).unsqueeze(0).to(if_II.device)\n            if_II_kwargs['inpainting_mask'] = inpainting_mask_II\n\n        stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)\n        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)\n\n        result['II'] = pil_images_II\n    else:\n        stageII_generations = None\n\n    if if_II is not None and if_III is not None:\n        _, _, image_h, image_w = high_res.shape\n        if_III_kwargs = if_III_kwargs or {}\n\n        stageIII_generations = []\n        for idx in range(len(stageII_generations)):\n            if if_III.use_diffusers:\n                if_III_kwargs['prompt'] = prompt[idx: idx+1]\n\n            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]\n            if_III_kwargs['seed'] = seed\n            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]\n            if negative_t5_embs is not None:\n                if_III_kwargs['negative_t5_embs'] = negative_t5_embs[idx:idx+1]\n            if_III_kwargs['progress'] = progress\n            if_III_kwargs['support_noise'] = high_res\n\n            if 'inpainting_mask' not in if_III_kwargs:\n                inpainting_mask_III = img_as_bool(resize(inpainting_mask[0].cpu(), (3, image_h, image_w)))\n                inpainting_mask_III = 
torch.from_numpy(inpainting_mask_III).unsqueeze(0).to(if_III.device)\n                if_III_kwargs['inpainting_mask'] = inpainting_mask_III\n\n            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)\n            stageIII_generations.append(_stageIII_generations)\n\n        stageIII_generations = torch.cat(stageIII_generations, 0)\n        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)\n\n        result['III'] = pil_images_III\n    else:\n        stageIII_generations = None\n\n    if return_tensors:\n        return result, (stageI_generations, stageII_generations, stageIII_generations)\n    else:\n        return result\n"
  },
  {
    "path": "deepfloyd_if/pipelines/style_transfer.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom datetime import datetime\n\nimport PIL\nimport torch\n\nfrom .utils import _prepare_pil_image\n\n\ndef style_transfer(\n    t5,\n    if_I,\n    if_II,\n    if_III=None,\n    *,\n    support_pil_img,\n    style_prompt,\n    prompt=None,\n    negative_prompt=None,\n    seed=None,\n    if_I_kwargs=None,\n    if_II_kwargs=None,\n    if_III_kwargs=None,\n    progress=True,\n    return_tensors=False,\n    disable_watermark=False,\n):\n    assert isinstance(support_pil_img, PIL.Image.Image)\n\n    bs = len(style_prompt)\n\n    if seed is None:\n        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))\n\n    if prompt is not None:\n        t5_embs = t5.get_text_embeddings(prompt)\n    else:\n        t5_embs = t5.get_text_embeddings(style_prompt)\n\n    style_t5_embs = t5.get_text_embeddings(style_prompt)\n\n    if negative_prompt is not None:\n        if isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n        negative_t5_embs = t5.get_text_embeddings(negative_prompt)\n    else:\n        negative_t5_embs = None\n\n    low_res = _prepare_pil_image(support_pil_img, 64)\n    mid_res = _prepare_pil_image(support_pil_img, 256)\n    # high_res = _prepare_pil_image(support_pil_img, 1024)\n\n    result = {}\n    if if_I is not None:\n        _, _, image_h, image_w = low_res.shape\n        if_I_kwargs = if_I_kwargs or {'sample_timestep_respacing': '20,20,20,20,10,0,0,0,0,0'}\n        if_I_kwargs['seed'] = seed\n        if_I_kwargs['progress'] = progress\n        if_I_kwargs['aspect_ratio'] = f'{image_w}:{image_h}'\n\n        if_I_kwargs['t5_embs'] = t5_embs\n        if_I_kwargs['style_t5_embs'] = style_t5_embs\n        if_I_kwargs['positive_t5_embs'] = style_t5_embs\n        if_I_kwargs['negative_t5_embs'] = negative_t5_embs\n\n        if_I_kwargs['support_noise'] = low_res\n\n        stageI_generations, _ = if_I.embeddings_to_image(**if_I_kwargs)\n        pil_images_I = 
if_I.to_images(stageI_generations, disable_watermark=disable_watermark)\n\n        result['I'] = pil_images_I\n    else:\n        stageI_generations = None\n\n    if if_II is not None:\n        if stageI_generations is None:\n            stageI_generations = low_res.repeat(bs, 1, 1, 1)\n\n        if_II_kwargs = if_II_kwargs or {}\n        if_II_kwargs['low_res'] = stageI_generations\n        if_II_kwargs['seed'] = seed\n        if_II_kwargs['t5_embs'] = t5_embs\n        if_II_kwargs['style_t5_embs'] = style_t5_embs\n        if_II_kwargs['positive_t5_embs'] = style_t5_embs\n        if_II_kwargs['negative_t5_embs'] = negative_t5_embs\n        if_II_kwargs['progress'] = progress\n\n        if_II_kwargs['support_noise'] = mid_res\n\n        stageII_generations, _meta = if_II.embeddings_to_image(**if_II_kwargs)\n        pil_images_II = if_II.to_images(stageII_generations, disable_watermark=disable_watermark)\n\n        result['II'] = pil_images_II\n    else:\n        stageII_generations = None\n\n    if if_II is not None and if_III is not None:\n        if_III_kwargs = if_III_kwargs or {}\n\n        stageIII_generations = []\n        for idx in range(len(stageII_generations)):\n            if if_III.use_diffusers:\n                if_III_kwargs['prompt'] = prompt[idx: idx+1] if prompt is not None else style_prompt[idx: idx+1]\n\n            if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]\n            if_III_kwargs['seed'] = seed\n            if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]\n            if_III_kwargs['progress'] = progress\n            style_t5_embs = if_II_kwargs.get('style_t5_embs')\n            if style_t5_embs is not None:\n                style_t5_embs = style_t5_embs[idx:idx+1]\n            positive_t5_embs = if_II_kwargs.get('positive_t5_embs')\n            if positive_t5_embs is not None:\n                positive_t5_embs = positive_t5_embs[idx:idx+1]\n            if_III_kwargs['style_t5_embs'] = style_t5_embs\n            
if_III_kwargs['positive_t5_embs'] = positive_t5_embs\n\n            _stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)\n            stageIII_generations.append(_stageIII_generations)\n\n        stageIII_generations = torch.cat(stageIII_generations, 0)\n        pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)\n\n        result['III'] = pil_images_III\n    else:\n        stageIII_generations = None\n\n    if return_tensors:\n        return result, (stageI_generations, stageII_generations, stageIII_generations)\n    else:\n        return result\n"
  },
  {
    "path": "deepfloyd_if/pipelines/super_resolution.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom datetime import datetime\n\nimport PIL\nfrom .utils import _prepare_pil_image\n\n\ndef super_resolution(\n    t5,\n    if_III=None,\n    *,\n    support_pil_img,\n    prompt=None,\n    negative_prompt=None,\n    seed=None,\n    if_III_kwargs=None,\n    progress=True,\n    img_size=256,\n    img_scale=4.0,\n    return_tensors=False,\n    disable_watermark=False,\n):\n    assert isinstance(support_pil_img, PIL.Image.Image)\n    assert img_size % 8 == 0\n\n    if seed is None:\n        seed = int((datetime.utcnow().timestamp() * 10 ** 6) % (2 ** 32 - 1))\n\n    if prompt is not None:\n        t5_embs = t5.get_text_embeddings(prompt)\n    else:\n        t5_embs = t5.get_text_embeddings('')\n\n    if negative_prompt is not None:\n        if isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n        negative_t5_embs = t5.get_text_embeddings(negative_prompt)\n    else:\n        negative_t5_embs = None\n\n    low_res = _prepare_pil_image(support_pil_img, img_size)\n\n    result = {}\n\n    bs = 1\n    if_III_kwargs = if_III_kwargs or {}\n\n    if if_III.use_diffusers:\n        if_III_kwargs['prompt'] = prompt\n\n    if_III_kwargs['low_res'] = low_res.repeat(bs, 1, 1, 1)\n    if_III_kwargs['seed'] = seed\n    if_III_kwargs['t5_embs'] = t5_embs\n    if_III_kwargs['negative_t5_embs'] = negative_t5_embs\n    if_III_kwargs['progress'] = progress\n    if_III_kwargs['img_scale'] = img_scale\n\n    stageIII_generations, _meta = if_III.embeddings_to_image(**if_III_kwargs)\n    pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=disable_watermark)\n    result['III'] = pil_images_III\n\n    if return_tensors:\n        return result, (stageIII_generations,)\n    else:\n        return result\n"
  },
  {
    "path": "deepfloyd_if/pipelines/utils.py",
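    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nimport numpy as np\nfrom PIL import Image\n\n\ndef _prepare_pil_image(raw_pil_img, img_size):\n    \"\"\"\n    Resize a PIL image to a stage's working resolution and convert it to a tensor.\n\n    The shorter side becomes `img_size`; the longer side is scaled to keep the aspect ratio\n    and rounded to the nearest multiple of 8. Pixel values are mapped from [0, 255] to [-1, 1].\n    Returns a float32 tensor of shape (1, 3, H, W).\n    \"\"\"\n    raw_pil_img = raw_pil_img.convert('RGB')\n    w, h = raw_pil_img.size\n    coef = w / h\n    image_h, image_w = img_size, img_size\n    # keep the aspect ratio: scale the longer side and round it to a multiple of 8\n    if coef >= 1:\n        image_w = int(round(img_size / 8 * coef) * 8)\n    else:\n        image_h = int(round(img_size / 8 / coef) * 8)\n\n    pil_img = raw_pil_img.resize(\n        (image_w, image_h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None\n    )\n    img = np.array(pil_img)\n    # map uint8 [0, 255] to float32 [-1, 1]\n    img = img.astype(np.float32) / 127.5 - 1\n    # HWC -> CHW, then add a batch dimension\n    img = np.transpose(img, [2, 0, 1])\n    img = torch.from_numpy(img).unsqueeze(0)\n    return img\n"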
  },
  {
    "path": "deepfloyd_if/utils.py",
    "content": "# -*- coding: utf-8 -*-\nfrom os.path import abspath, dirname, join\n\nimport torch\nimport numpy as np\nfrom PIL import Image, ImageFilter\n\nRESOURCES_ROOT = join(abspath(dirname(__file__)), 'resources')\n\n\ndef drop_shadow(image, offset=(5, 5), background=0xffffff, shadow=0x444444, border=8, iterations=3):\n    \"\"\"\n    Drop shadows with PIL.\n    Author: Kevin Schluff\n    License: Python license\n    https://code.activestate.com/recipes/474116/\n\n    Add a gaussian blur drop shadow to an image.\n\n    image       - The image to overlay on top of the shadow.\n    offset      - Offset of the shadow from the image as an (x,y) tuple.  Can be\n                  positive or negative.\n    background  - Background colour behind the image.\n    shadow      - Shadow colour (darkness).\n    border      - Width of the border around the image.  This must be wide\n                  enough to account for the blurring of the shadow.\n    iterations  - Number of times to apply the filter.  More iterations\n                  produce a more blurred shadow, but increase processing time.\n    \"\"\"\n\n    # Create the backdrop image -- a box in the background colour with a\n    # shadow on it.\n    total_width = image.size[0] + abs(offset[0]) + 2 * border\n    total_height = image.size[1] + abs(offset[1]) + 2 * border\n    back = Image.new(image.mode, (total_width, total_height), background)\n\n    # Place the shadow, taking into account the offset from the image\n    shadow_left = border + max(offset[0], 0)\n    shadow_top = border + max(offset[1], 0)\n    back.paste(shadow, [shadow_left, shadow_top, shadow_left + image.size[0], shadow_top + image.size[1]])\n\n    # Apply the filter to blur the edges of the shadow.  
Since a small kernel\n    # is used, the filter must be applied repeatedly to get a decent blur.\n    n = 0\n    while n < iterations:\n        back = back.filter(ImageFilter.BLUR)\n        n += 1\n\n    # Paste the input image onto the shadow backdrop\n    image_left = border - min(offset[0], 0)\n    image_top = border - min(offset[1], 0)\n    back.paste(image, (image_left, image_top))\n    return back\n\n\ndef pil_list_to_torch_tensors(pil_images):\n    \"\"\"\n    Stack RGB PIL images into a single uint8 tensor of shape (N, 3, H, W).\n\n    All images must have the same size, since they are concatenated along the batch dimension.\n    \"\"\"\n    result = []\n    for pil_image in pil_images:\n        image = np.array(pil_image, dtype=np.uint8)  # (H, W, 3)\n        image = torch.from_numpy(image)\n        image = image.permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)\n        result.append(image)\n    return torch.cat(result, dim=0)\n"
  },
  {
    "path": "requirements-dev.txt",
    "content": "-r requirements-test.txt\npre-commit\n"
  },
  {
    "path": "requirements-test.txt",
    "content": "-r requirements.txt\npytest\npytest-cov\n"
  },
  {
    "path": "requirements.txt",
    "content": "tqdm\nnumpy\ntorch<2.0.0\ntorchvision\nomegaconf\nmatplotlib\nPillow>=9.2.0\nhuggingface_hub>=0.13.2\ntransformers~=4.25.1\naccelerate~=0.15.0\ndiffusers~=0.16.0\ntokenizers~=0.13.2\nsentencepiece~=0.1.97\nftfy~=6.1.1\nbeautifulsoup4~=4.11.1\n"
  },
  {
    "path": "setup.cfg",
    "content": "[pep8]\nmax-line-length = 120\nexclude = .tox,*migrations*,.json\n\n[flake8]\nmax-line-length = 120\nexclude = .tox,*migrations*,.json\n\n[autopep8-wrapper]\nexclude = .tox,*migrations*,.json\n\n[check-docstring-first]\nexclude = .tox,*migrations*,.json\n"
  },
  {
    "path": "setup.py",
    "content": "# -*- coding: utf-8 -*-\nimport os\nimport re\nfrom setuptools import setup\n\n\ndef read(filename):\n    with open(os.path.join(os.path.dirname(__file__), filename), encoding='utf-8') as f:\n        file_content = f.read()\n    return file_content\n\n\ndef get_requirements():\n    requirements = []\n    for requirement in read('requirements.txt').splitlines():\n        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):\n            parsed_requires = re.findall(r'#egg=([\\w\\d\\.]+)-([\\d\\.]+)$', requirement)\n            if parsed_requires:\n                package, version = parsed_requires[0]\n                requirements.append(f'{package}=={version}')\n            else:\n                print('WARNING! To match dependency links correctly, specify the package name and version '\n                      'as <dependency url>#egg=<package_name>-<version>')\n        else:\n            requirements.append(requirement)\n    return requirements\n\n\ndef get_links():\n    return [\n        requirement for requirement in read('requirements.txt').splitlines()\n        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')\n    ]\n\n\ndef get_version():\n    \"\"\" Get version from the package without actually importing it. 
\"\"\"\n    init = read('deepfloyd_if/__init__.py')\n    for line in init.split('\\n'):\n        if line.startswith('__version__'):\n            return eval(line.split('=')[1])\n\n\nsetup(\n    name='deepfloyd_if',\n    version=get_version(),\n    author='DeepFloyd, StabilityAI',\n    author_email='shonenkov@gmail.com',\n    description='DeepFloyd-IF (Imagen Free)',\n    packages=['deepfloyd_if', 'deepfloyd_if/model', 'deepfloyd_if/modules', 'deepfloyd_if/pipelines',\n              'deepfloyd_if/resources'],\n    package_data={'deepfloyd_if/resources': ['*.png', '*.npy', '*.npz']},\n    install_requires=get_requirements(),\n    dependency_links=get_links(),\n    long_description=read('README.md'),\n    long_description_content_type='text/markdown',\n)\n"
  }
]