[
  {
    "path": ".gitignore",
    "content": "storage/outputs/*.png\nstorage/init/*.png\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\nlog.txt\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\nwaifu-diffusion/\n"
  },
  {
    "path": "LICENSE",
    "content": "                    GNU GENERAL PUBLIC LICENSE\n                       Version 2, June 1991\n\n Copyright (C) 1989, 1991 Free Software Foundation, Inc.,\n 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The licenses for most software are designed to take away your\nfreedom to share and change it.  By contrast, the GNU General Public\nLicense is intended to guarantee your freedom to share and change free\nsoftware--to make sure the software is free for all its users.  This\nGeneral Public License applies to most of the Free Software\nFoundation's software and to any other program whose authors commit to\nusing it.  (Some other Free Software Foundation software is covered by\nthe GNU Lesser General Public License instead.)  You can apply it to\nyour programs, too.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthis service if you wish), that you receive source code or can get it\nif you want it, that you can change the software or use pieces of it\nin new free programs; and that you know you can do these things.\n\n  To protect your rights, we need to make restrictions that forbid\nanyone to deny you these rights or to ask you to surrender the rights.\nThese restrictions translate to certain responsibilities for you if you\ndistribute copies of the software, or if you modify it.\n\n  For example, if you distribute copies of such a program, whether\ngratis or for a fee, you must give the recipients all the rights that\nyou have.  You must make sure that they, too, receive or can get the\nsource code.  
And you must show them these terms so they know their\nrights.\n\n  We protect your rights with two steps: (1) copyright the software, and\n(2) offer you this license which gives you legal permission to copy,\ndistribute and/or modify the software.\n\n  Also, for each author's protection and ours, we want to make certain\nthat everyone understands that there is no warranty for this free\nsoftware.  If the software is modified by someone else and passed on, we\nwant its recipients to know that what they have is not the original, so\nthat any problems introduced by others will not reflect on the original\nauthors' reputations.\n\n  Finally, any free program is threatened constantly by software\npatents.  We wish to avoid the danger that redistributors of a free\nprogram will individually obtain patent licenses, in effect making the\nprogram proprietary.  To prevent this, we have made it clear that any\npatent must be licensed for everyone's free use or not licensed at all.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                    GNU GENERAL PUBLIC LICENSE\n   TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION\n\n  0. This License applies to any program or other work which contains\na notice placed by the copyright holder saying it may be distributed\nunder the terms of this General Public License.  The \"Program\", below,\nrefers to any such program or work, and a \"work based on the Program\"\nmeans either the Program or any derivative work under copyright law:\nthat is to say, a work containing the Program or a portion of it,\neither verbatim or with modifications and/or translated into another\nlanguage.  (Hereinafter, translation is included without limitation in\nthe term \"modification\".)  Each licensee is addressed as \"you\".\n\nActivities other than copying, distribution and modification are not\ncovered by this License; they are outside its scope.  
The act of\nrunning the Program is not restricted, and the output from the Program\nis covered only if its contents constitute a work based on the\nProgram (independent of having been made by running the Program).\nWhether that is true depends on what the Program does.\n\n  1. You may copy and distribute verbatim copies of the Program's\nsource code as you receive it, in any medium, provided that you\nconspicuously and appropriately publish on each copy an appropriate\ncopyright notice and disclaimer of warranty; keep intact all the\nnotices that refer to this License and to the absence of any warranty;\nand give any other recipients of the Program a copy of this License\nalong with the Program.\n\nYou may charge a fee for the physical act of transferring a copy, and\nyou may at your option offer warranty protection in exchange for a fee.\n\n  2. You may modify your copy or copies of the Program or any portion\nof it, thus forming a work based on the Program, and copy and\ndistribute such modifications or work under the terms of Section 1\nabove, provided that you also meet all of these conditions:\n\n    a) You must cause the modified files to carry prominent notices\n    stating that you changed the files and the date of any change.\n\n    b) You must cause any work that you distribute or publish, that in\n    whole or in part contains or is derived from the Program or any\n    part thereof, to be licensed as a whole at no charge to all third\n    parties under the terms of this License.\n\n    c) If the modified program normally reads commands interactively\n    when run, you must cause it, when started running for such\n    interactive use in the most ordinary way, to print or display an\n    announcement including an appropriate copyright notice and a\n    notice that there is no warranty (or else, saying that you provide\n    a warranty) and that users may redistribute the program under\n    these conditions, and telling the user how to view a copy of this\n  
  License.  (Exception: if the Program itself is interactive but\n    does not normally print such an announcement, your work based on\n    the Program is not required to print an announcement.)\n\nThese requirements apply to the modified work as a whole.  If\nidentifiable sections of that work are not derived from the Program,\nand can be reasonably considered independent and separate works in\nthemselves, then this License, and its terms, do not apply to those\nsections when you distribute them as separate works.  But when you\ndistribute the same sections as part of a whole which is a work based\non the Program, the distribution of the whole must be on the terms of\nthis License, whose permissions for other licensees extend to the\nentire whole, and thus to each and every part regardless of who wrote it.\n\nThus, it is not the intent of this section to claim rights or contest\nyour rights to work written entirely by you; rather, the intent is to\nexercise the right to control the distribution of derivative or\ncollective works based on the Program.\n\nIn addition, mere aggregation of another work not based on the Program\nwith the Program (or with a work based on the Program) on a volume of\na storage or distribution medium does not bring the other work under\nthe scope of this License.\n\n  3. 
You may copy and distribute the Program (or a work based on it,\nunder Section 2) in object code or executable form under the terms of\nSections 1 and 2 above provided that you also do one of the following:\n\n    a) Accompany it with the complete corresponding machine-readable\n    source code, which must be distributed under the terms of Sections\n    1 and 2 above on a medium customarily used for software interchange; or,\n\n    b) Accompany it with a written offer, valid for at least three\n    years, to give any third party, for a charge no more than your\n    cost of physically performing source distribution, a complete\n    machine-readable copy of the corresponding source code, to be\n    distributed under the terms of Sections 1 and 2 above on a medium\n    customarily used for software interchange; or,\n\n    c) Accompany it with the information you received as to the offer\n    to distribute corresponding source code.  (This alternative is\n    allowed only for noncommercial distribution and only if you\n    received the program in object code or executable form with such\n    an offer, in accord with Subsection b above.)\n\nThe source code for a work means the preferred form of the work for\nmaking modifications to it.  For an executable work, complete source\ncode means all the source code for all modules it contains, plus any\nassociated interface definition files, plus the scripts used to\ncontrol compilation and installation of the executable.  
However, as a\nspecial exception, the source code distributed need not include\nanything that is normally distributed (in either source or binary\nform) with the major components (compiler, kernel, and so on) of the\noperating system on which the executable runs, unless that component\nitself accompanies the executable.\n\nIf distribution of executable or object code is made by offering\naccess to copy from a designated place, then offering equivalent\naccess to copy the source code from the same place counts as\ndistribution of the source code, even though third parties are not\ncompelled to copy the source along with the object code.\n\n  4. You may not copy, modify, sublicense, or distribute the Program\nexcept as expressly provided under this License.  Any attempt\notherwise to copy, modify, sublicense or distribute the Program is\nvoid, and will automatically terminate your rights under this License.\nHowever, parties who have received copies, or rights, from you under\nthis License will not have their licenses terminated so long as such\nparties remain in full compliance.\n\n  5. You are not required to accept this License, since you have not\nsigned it.  However, nothing else grants you permission to modify or\ndistribute the Program or its derivative works.  These actions are\nprohibited by law if you do not accept this License.  Therefore, by\nmodifying or distributing the Program (or any work based on the\nProgram), you indicate your acceptance of this License to do so, and\nall its terms and conditions for copying, distributing or modifying\nthe Program or works based on it.\n\n  6. Each time you redistribute the Program (or any work based on the\nProgram), the recipient automatically receives a license from the\noriginal licensor to copy, distribute or modify the Program subject to\nthese terms and conditions.  
You may not impose any further\nrestrictions on the recipients' exercise of the rights granted herein.\nYou are not responsible for enforcing compliance by third parties to\nthis License.\n\n  7. If, as a consequence of a court judgment or allegation of patent\ninfringement or for any other reason (not limited to patent issues),\nconditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot\ndistribute so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you\nmay not distribute the Program at all.  For example, if a patent\nlicense would not permit royalty-free redistribution of the Program by\nall those who receive copies directly or indirectly through you, then\nthe only way you could satisfy both it and this License would be to\nrefrain entirely from distribution of the Program.\n\nIf any portion of this section is held invalid or unenforceable under\nany particular circumstance, the balance of the section is intended to\napply and the section as a whole is intended to apply in other\ncircumstances.\n\nIt is not the purpose of this section to induce you to infringe any\npatents or other property right claims or to contest validity of any\nsuch claims; this section has the sole purpose of protecting the\nintegrity of the free software distribution system, which is\nimplemented by public license practices.  Many people have made\ngenerous contributions to the wide range of software distributed\nthrough that system in reliance on consistent application of that\nsystem; it is up to the author/donor to decide if he or she is willing\nto distribute software through any other system and a licensee cannot\nimpose that choice.\n\nThis section is intended to make thoroughly clear what is believed to\nbe a consequence of the rest of this License.\n\n  8. 
If the distribution and/or use of the Program is restricted in\ncertain countries either by patents or by copyrighted interfaces, the\noriginal copyright holder who places the Program under this License\nmay add an explicit geographical distribution limitation excluding\nthose countries, so that distribution is permitted only in or among\ncountries not thus excluded.  In such case, this License incorporates\nthe limitation as if written in the body of this License.\n\n  9. The Free Software Foundation may publish revised and/or new versions\nof the General Public License from time to time.  Such new versions will\nbe similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\nEach version is given a distinguishing version number.  If the Program\nspecifies a version number of this License which applies to it and \"any\nlater version\", you have the option of following the terms and conditions\neither of that version or of any later version published by the Free\nSoftware Foundation.  If the Program does not specify a version number of\nthis License, you may choose any version ever published by the Free Software\nFoundation.\n\n  10. If you wish to incorporate parts of the Program into other free\nprograms whose distribution conditions are different, write to the author\nto ask for permission.  For software which is copyrighted by the Free\nSoftware Foundation, write to the Free Software Foundation; we sometimes\nmake exceptions for this.  Our decision will be guided by the two goals\nof preserving the free status of all derivatives of our free software and\nof promoting the sharing and reuse of software generally.\n\n                            NO WARRANTY\n\n  11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY\nFOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW.  
EXCEPT WHEN\nOTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES\nPROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED\nOR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF\nMERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.  THE ENTIRE RISK AS\nTO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU.  SHOULD THE\nPROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,\nREPAIR OR CORRECTION.\n\n  12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR\nREDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,\nINCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING\nOUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED\nTO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY\nYOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER\nPROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE\nPOSSIBILITY OF SUCH DAMAGES.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  
It is safest\nto attach them to the start of each source file to most effectively\nconvey the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software; you can redistribute it and/or modify\n    it under the terms of the GNU General Public License as published by\n    the Free Software Foundation; either version 2 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU General Public License for more details.\n\n    You should have received a copy of the GNU General Public License along\n    with this program; if not, write to the Free Software Foundation, Inc.,\n    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.\n\nAlso add information on how to contact you by electronic and paper mail.\n\nIf the program is interactive, make it output a short notice like this\nwhen it starts in an interactive mode:\n\n    Gnomovision version 69, Copyright (C) year name of author\n    Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.\n    This is free software, and you are welcome to redistribute it\n    under certain conditions; type `show c' for details.\n\nThe hypothetical commands `show w' and `show c' should show the appropriate\nparts of the General Public License.  Of course, the commands you use may\nbe called something other than `show w' and `show c'; they could even be\nmouse-clicks or menu items--whatever suits your program.\n\nYou should also get your employer (if you work as a programmer) or your\nschool, if any, to sign a \"copyright disclaimer\" for the program, if\nnecessary.  
Here is a sample; alter the names:\n\n  Yoyodyne, Inc., hereby disclaims all copyright interest in the program\n  `Gnomovision' (which makes passes at compilers) written by James Hacker.\n\n  <signature of Ty Coon>, 1 April 1989\n  Ty Coon, President of Vice\n\nThis General Public License does not permit incorporating your program into\nproprietary programs.  If your program is a subroutine library, you may\nconsider it more useful to permit linking proprietary applications with the\nlibrary.  If this is what you want to do, use the GNU Lesser General\nPublic License instead of this License.\n"
  },
  {
    "path": "README.md",
    "content": "# Shanghai - AI Powered Art in a Discord Bot!\n\n<img src=\"https://cdn.discordapp.com/attachments/971549874514444358/1012400070559277086/1502073419.png?3867929\" width=\"50%\" height=\"50%\">\n\n### Any questions or need help? Come hop on by to our Discord server!\n\n[![Discord Server](https://discordapp.com/api/guilds/930499730843250783/widget.png?style=banner2)](https://discord.gg/Sx6Spmsgx7)\n\n\n## Setup\nMake sure you have the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) installed.\n\nClone the repository and enter it:\n````\ngit clone https://github.com/harubaru/discord-stable-diffusion.git\ncd discord-stable-diffusion\n````\n\n#### WINDOWS SETUP\nRun `setup.bat`. If you run into any errors, try running the file as administrator.\n\nIf you are on a Windows 10 system, run `win10patch.bat`.\n\nModify the `run.bat` file, where\n* `--model_path` is the path to the model (make sure to replace any backslashes with double backslashes),\n* `--token=` is the token of the Discord bot,\n* `--hf_token=` is your Hugging Face token (can be found [here](https://huggingface.co/settings/tokens))\n\nRun the `run.bat` file.\n\n#### LINUX SETUP\nRun `./setup.sh`. If you run into any errors, try using `sudo ./setup.sh`.\n\nModify the `run.sh` file, where\n* `--model_path` is the path to the model,\n* `--token=` is the token of the Discord bot,\n* `--hf_token=` is your Hugging Face token (can be found [here](https://huggingface.co/settings/tokens))\n\nRun `./run.sh`.\n\n### Quickstart\n#### Text to Image\n\nTo generate an image from text, use the ``/dream`` command and include your prompt as the query.
There are tons of parameters to play with, so go wild!\n\n![image](https://user-images.githubusercontent.com/26317155/186722689-3cbca12a-531c-47f7-b87f-99918e9ed232.png)\n\n![image](https://user-images.githubusercontent.com/26317155/186721768-3684f629-90c3-4ef2-82b8-1310200df437.png)\n\n\n#### Image to Image\n\nTo generate an image from another image, use the ``/dream`` command and include the `init_image` and `strength` parameters. The image is supplied as an attachment through the `init_image` option.\n\n![image](https://user-images.githubusercontent.com/26317155/186722463-ec3a6d24-36c1-48f8-b09a-57651706848c.png)\n\n![image](https://user-images.githubusercontent.com/26317155/186722528-7e652a21-fd02-4071-9fc1-87a31dfb6e63.png)\n\n\n#### (Experimental) Inpainting\n\nTo fill in a masked region of an image, use the ``/dream`` command with a prompt and the `init_image`, `mask_image` and `strength` parameters. The mask needs to consist of black pixels in a transparent image.\n\n![image](https://user-images.githubusercontent.com/26317155/186722970-71a662dc-16a8-4bb4-8696-3bafb3e08e65.png)\n\n"
  },
  {
    "path": "__main__.py",
    "content": "import os\nimport sys\nimport argparse\nimport asyncio\nfrom src.core.logging import get_logger\nfrom src.bot.shanghai import Shanghai\n\nlogger = get_logger(__name__)\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='Shanghai - A Discord bot for AI powered utilities.',\n        usage='shanghai [arguments]'\n    )\n\n    parser.add_argument('--prefix', type=str, help='The prefix to use for commands.', default='s!')\n    parser.add_argument('--token', type=str, help='The token to use for authentication.')\n    parser.add_argument('--hf_token', type=str, help='The token to use for HuggingFace authentication.', default=None)\n    parser.add_argument('--model_path', type=str, help='Path to the model.', default=None)\n\n    return parser.parse_args()\n\nasync def shutdown(bot):\n    \"\"\"Close the bot's connection, if a bot was ever constructed.\"\"\"\n    # main() calls this on every exit path, including when Shanghai(args) itself\n    # raised, in which case no bot instance exists yet.\n    if bot is not None:\n        await bot.close()\n\ndef main():\n    \"\"\"Run the bot and exit with status 1 if it stopped because of an error.\"\"\"\n    shanghai = None\n    exit_code = 0\n    args = parse_args()\n\n    try:\n        shanghai = Shanghai(args)\n        logger.info('Executing bot.')\n        shanghai.run(args.token)\n    except KeyboardInterrupt:\n        logger.info('Keyboard interrupt received. Exiting.')\n        asyncio.run(shutdown(shanghai))\n    except SystemExit:\n        logger.info('System exit received. Exiting.')\n        asyncio.run(shutdown(shanghai))\n    except Exception as e:\n        logger.error(e)\n        # Report the failure to the calling shell instead of exiting with 0.\n        exit_code = 1\n        asyncio.run(shutdown(shanghai))\n    finally:\n        sys.exit(exit_code)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "models/.keep",
    "content": "壊れたカーテンの隙間から\n壁を埋めるのは\n暴言？妄言？知りません。"
  },
  {
    "path": "models/v1-inference.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: src.stablediffusion.ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    scheduler_config: # 10000 warmup steps\n      target: src.stablediffusion.ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 10000 ]\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. ]\n\n    personalization_config:\n      target: src.stablediffusion.ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n\n    unet_config:\n      target: src.stablediffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n\n    first_stage_config:\n      target: src.stablediffusion.ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 
3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: src.stablediffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder\n"
  },
  {
    "path": "requirements.txt",
    "content": "--extra-index-url https://download.pytorch.org/whl/cu117\ntorch\ndiffusers\nnumpy\nPillow\npydantic\ngit+https://github.com/Pycord-Development/pycord\nomegaconf==2.1.1\npytorch-lightning==1.4.2\ntaming-transformers-rom1504==0.0.6\ntest-tube>=0.7.5\ntorch-fidelity==0.3.0\ntorchmetrics==0.6.0\ntransformers==4.19.2\ngit+https://github.com/openai/CLIP.git@main#egg=clip\ngit+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion\n"
  },
  {
    "path": "run.bat",
    "content": "venv\\Scripts\\python.exe . --model_path \"\" --token=\"\" --hf_token=\"\"\n"
  },
  {
    "path": "run.sh",
    "content": "venv/bin/python . --model_path \"\" --token=\"\" --hf_token=\"\"\n"
  },
  {
    "path": "setup.bat",
    "content": "python -m venv venv\nvenv\\Scripts\\pip.exe install -r requirements.txt"
  },
  {
    "path": "setup.sh",
    "content": "python -m venv venv\nvenv/bin/pip install -r requirements.txt\n"
  },
  {
    "path": "src/bot/shanghai.py",
    "content": "import asyncio\nimport os\nfrom abc import ABC\n\nimport discord\nfrom discord.ext import commands\nfrom src.core.logging import get_logger\n\n\nclass Shanghai(commands.Bot, ABC):\n    def __init__(self, args):\n        intents = discord.Intents.default()\n        intents.members = True\n        super().__init__(command_prefix=args.prefix, intents=intents)\n        self.args = args\n        self.logger = get_logger(__name__)\n        self.load_extension('src.bot.stablecog')\n\n    async def on_ready(self):\n        self.logger.info(f'Logged in as {self.user.name} ({self.user.id})')\n        await self.change_presence(\n            activity=discord.Activity(type=discord.ActivityType.watching, name='you over the seven seas.'))\n\n    async def on_message(self, message):\n        # Overriding on_message replaces commands.Bot's default handler, so prefix\n        # commands must be dispatched explicitly or they never run.\n        await self.process_commands(message)\n\n        if message.author != self.user:\n            return\n        # Treat the message as a generation result only if its first embed field is named 'command'.\n        if not message.embeds or not message.embeds[0].fields:\n            return\n        if message.embeds[0].fields[0].name != 'command':\n            return\n        try:\n            await message.add_reaction('❌')\n        except discord.HTTPException as e:\n            # Best effort: failing to add the reaction must not break message handling.\n            self.logger.error(f'Could not add delete reaction: {e}')\n\n    async def on_raw_reaction_add(self, ctx):\n        # ctx is the raw reaction event payload, not a command context.\n        if ctx.emoji.name != '❌':\n            return\n        # Ignore the bot's own reaction (added in on_message) to avoid a needless fetch.\n        if ctx.user_id == self.user.id:\n            return\n        # member is None for reactions outside a guild (e.g. in DMs).\n        if ctx.member is None:\n            return\n        channel = self.get_channel(ctx.channel_id)\n        if channel is None:\n            # Channel not in the cache; fall back to an API lookup.\n            channel = await self.fetch_channel(ctx.channel_id)\n        message = await channel.fetch_message(ctx.message_id)\n        if not message.embeds:\n            return\n        # Compare the footer with the reacting user to check the generation was theirs.\n        if message.embeds[0].footer.text == f'{ctx.member.name}#{ctx.member.discriminator}':\n            await message.delete()\n"
  },
  {
    "path": "src/bot/stablecog.py",
    "content": "import traceback\nfrom asyncio import AbstractEventLoop\nfrom threading import Thread\n\nimport requests\nimport asyncio\nimport discord\nfrom discord.ext import commands\nfrom typing import Optional\nfrom io import BytesIO\nfrom PIL import Image\nfrom discord import option\nimport random\nimport time\n\nfrom src.stablediffusion.text2image_compvis import Text2Image\n\nembed_color = discord.Colour.from_rgb(215, 195, 134)\n\n\nclass QueueObject:\n    def __init__(self, ctx, prompt, height, width, guidance_scale, steps, seed, strength,\n                 init_image, mask_image, sampler_name, command_str):\n        self.ctx = ctx\n        self.prompt = prompt\n        self.height = height\n        self.width = width\n        self.guidance_scale = guidance_scale\n        self.steps = steps\n        self.seed = seed\n        self.strength = strength\n        self.init_image = init_image\n        self.mask_image = mask_image\n        self.sampler_name = sampler_name\n        self.command_str = command_str\n\n\nclass StableCog(commands.Cog, name='Stable Diffusion', description='Create images from natural language.'):\n    def __init__(self, bot):\n        self.dream_thread = Thread()\n        self.text2image_model = Text2Image(model_path=bot.args.model_path)\n        self.event_loop = asyncio.get_event_loop()\n        self.queue = []\n        self.bot = bot\n\n\n    @commands.slash_command(name='dream', description='Create an image.')\n    @option(\n        'prompt',\n        str,\n        description='A prompt to condition the model with.',\n        required=True,\n    )\n    @option(\n        'height',\n        int,\n        description='Height of the generated image.',\n        required=False,\n        choices=[x for x in range(192, 832, 64)]\n    )\n    @option(\n        'width',\n        int,\n        description='Width of the generated image.',\n        required=False,\n        choices=[x for x in range(192, 832, 64)]\n    )\n    @option(\n        
'guidance_scale',\n        float,\n        description='Classifier-Free Guidance scale',\n        required=False,\n    )\n    @option(\n        'steps',\n        int,\n        description='The amount of steps to sample the model',\n        required=False,\n        choices=[x for x in range(5, 55, 5)]\n    )\n    @option(\n        'sampler',\n        str,\n        description='The sampler to use for generation',\n        required=False,\n        choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],\n        default='ddim'\n    )\n    @option(\n        'seed',\n        int,\n        description='The seed to use for reproduceability',\n        required=False,\n    )\n    @option(\n        'strength',\n        float,\n        description='The strength (0.0 to 1.0) used to apply the prompt to the init_image/mask_image'\n    )\n    @option(\n        'init_image',\n        discord.Attachment,\n        description='The image to initialize the latents with for denoising',\n        required=False,\n    )\n    @option(\n        'mask_image',\n        discord.Attachment,\n        description='The mask image to use for inpainting',\n        required=False,\n    )\n    async def dream_handler(self, ctx: discord.ApplicationContext, *, prompt: str, height: Optional[int] = 512,\n                            width: Optional[int] = 512, guidance_scale: Optional[float] = 7.0,\n                            steps: Optional[int] = 30,\n                            sampler: Optional[str] = 'k_euler_a',\n                            seed: Optional[int] = -1, strength: Optional[float] = None,\n                            init_image: Optional[discord.Attachment] = None,\n                            mask_image: Optional[discord.Attachment] = None):\n        print(f'Request -- {ctx.author.name}#{ctx.author.discriminator} -- Prompt: {prompt}')\n\n        if seed == -1:\n            seed = random.randint(0, 0xFFFFFFFF)\n\n        command_str = '/dream'\n        
command_str = command_str + f' prompt:{prompt} height:{str(height)} width:{width} guidance_scale:{guidance_scale} steps:{steps} sampler:{sampler} seed:{seed}'\n        if init_image or mask_image:\n            command_str = command_str + f' strength:{strength}'\n\n        if self.dream_thread.is_alive():\n            user_already_in_queue = False\n            for queue_object in self.queue:\n                if queue_object.ctx.author.id == ctx.author.id:\n                    user_already_in_queue = True\n                    break\n            if user_already_in_queue:\n                await ctx.send_response(\n                    content=f'Please wait for your current image to finish generating before generating a new image',\n                    ephemeral=True)\n            else:\n                self.queue.append(QueueObject(ctx, prompt, height, width, guidance_scale, steps, seed,\n                                              strength,\n                                              init_image, mask_image, sampler, command_str))\n                await ctx.send_response(\n                    content=f'Dreaming for <@{ctx.author.id}> - Queue Position: ``{len(self.queue)}`` - ``{command_str}``')\n        else:\n            await self.process_dream(QueueObject(ctx, prompt, height, width, guidance_scale, steps, seed,\n                                                 strength,\n                                                 init_image, mask_image, sampler, command_str))\n            await ctx.send_response(\n                content=f'Dreaming for <@{ctx.author.id}> - Queue Position: ``{len(self.queue)}`` - ``{command_str}``')\n\n    async def process_dream(self, queue_object: QueueObject):\n        self.dream_thread = Thread(target=self.dream,\n                                   args=(self.event_loop, queue_object))\n        self.dream_thread.start()\n\n    def dream(self, event_loop: AbstractEventLoop, queue_object: QueueObject):\n        try:\n            
start_time = time.time()\n            if (queue_object.init_image is None) and (queue_object.mask_image is None):\n                samples, seed = self.text2image_model.dream(queue_object.prompt, queue_object.steps, False, False, 0.0,\n                                                            1, 1, queue_object.guidance_scale, queue_object.seed,\n                                                            queue_object.height, queue_object.width, False,\n                                                            queue_object.sampler_name)\n            elif queue_object.init_image is not None:\n                image = Image.open(requests.get(queue_object.init_image.url, stream=True).raw).convert('RGB')\n                samples, seed = self.text2image_model.translation(queue_object.prompt, image, queue_object.steps, 0.0,\n                                                                  0,\n                                                                  0, queue_object.guidance_scale,\n                                                                  queue_object.strength, queue_object.seed,\n                                                                  queue_object.height, queue_object.width,\n                                                                  queue_object.sampler_name)\n            else:\n                image = Image.open(requests.get(queue_object.init_image.url, stream=True).raw).convert('RGB')\n                mask = Image.open(requests.get(queue_object.mask_image.url, stream=True).raw).convert('RGB')\n                samples, seed = self.text2image_model.inpaint(queue_object.prompt, image, mask, queue_object.steps, 0.0,\n                                                              1, 1, queue_object.guidance_scale,\n                                                              denoising_strength=queue_object.strength,\n                                                              seed=queue_object.seed, height=queue_object.height,\n      
                                                        width=queue_object.width,\n                                                              sampler_name=queue_object.sampler_name)\n            end_time = time.time()\n\n            with BytesIO() as buffer:\n                samples[0].save(buffer, 'PNG')\n                buffer.seek(0)\n                embed = discord.Embed()\n                embed.colour = embed_color\n                embed.add_field(name='command', value=f'``{queue_object.command_str}``', inline=False)\n                embed.add_field(name='compute used', value='``{0:.3f}`` seconds'.format(end_time - start_time),\n                                inline=False)\n                embed.add_field(name='delete', value='React with ❌ to delete your own generation')\n                # fix errors if user doesn't have pfp\n                if queue_object.ctx.author.avatar is None:\n                    embed.set_footer(\n                        text=f'{queue_object.ctx.author.name}#{queue_object.ctx.author.discriminator}')\n                else:\n                    embed.set_footer(\n                        text=f'{queue_object.ctx.author.name}#{queue_object.ctx.author.discriminator}',\n                        icon_url=queue_object.ctx.author.avatar.url)\n\n                event_loop.create_task(\n                    queue_object.ctx.channel.send(content=f'<@{queue_object.ctx.author.id}>', embed=embed,\n                                                  file=discord.File(fp=buffer, filename=f'{seed}.png')))\n        except Exception as e:\n            embed = discord.Embed(title='txt2img failed', description=f'{e}\\n{traceback.print_exc()}',\n                                  color=embed_color)\n            event_loop.create_task(queue_object.ctx.channel.send(embed=embed))\n        if self.queue:\n            event_loop.create_task(self.process_dream(self.queue.pop(0)))\n\n\ndef setup(bot):\n    bot.add_cog(StableCog(bot))\n"
  },
  {
    "path": "src/core/logging.py",
    "content": "import logging\n\nlogging.basicConfig(level=logging.INFO,\n                    format='[%(asctime)s] %(levelname)s: %(message)s',\n                    datefmt='%Y-%m-%d %H:%M:%S')\n\ndef get_logger(name):\n    return logging.getLogger(name)"
  },
  {
    "path": "src/scripts/win10patch.py",
    "content": "try:\n    file_path = 'venv\\\\lib\\\\site-packages\\\\torch\\\\distributed\\\\elastic\\\\timer\\\\file_based_local_timer.py'\n    with open(file_path, 'r+') as file:\n        old = file.read()\n        if 'SIGKILL' not in old:\n            print(file_path + ' already patched!')\n            exit(0)\n        file.seek(0)\n        file.write(old.replace('SIGKILL', 'SIGINT'))\n    print('Patched ' + file_path)\nexcept Exception as e:\n    print('Patch failed! Please report this either on github or to salt#7234\\nReason: ' + str(e))\n"
  },
  {
    "path": "src/stablediffusion/dream.py",
    "content": "import inspect\nimport warnings\nfrom typing import List, Optional, Union\n\nimport torch\n\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\n\n\nfrom PIL import Image\n\n\nclass StableDiffusionPipeline(DiffusionPipeline):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]\n    ):\n        super().__init__()\n        scheduler = scheduler.set_format(\"pt\")\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: Optional[int] = 512,\n        width: Optional[int] = 512,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        eta: Optional[float] = 0.0,\n        generator: Optional[torch.Generator] = None,\n        output_type: Optional[str] = \"pil\",\n        progress: Optional[bool] = False,\n        **kwargs,\n    ):\n        if \"torch_device\" in kwargs:\n            device = kwargs.pop(\"torch_device\")\n            warnings.warn(\n                \"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0.\"\n                \" Consider using `pipe.to(torch_device)` instead.\"\n            )\n\n            # Set device as before (to be removed in 0.3.0)\n            if device is None:\n                device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n            self.to(device)\n\n        if isinstance(prompt, 
str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        # get prompt text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the intial random noise\n        latents = torch.randn(\n            (batch_size, self.unet.in_channels, height // 8, width // 8),\n            generator=generator,\n            
device=self.device,\n        )\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        if accepts_offset:\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n\n        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents * self.scheduler.sigmas[0]\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n        \n        images = []\n\n        for i, t in tqdm(enumerate(self.scheduler.timesteps)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            if isinstance(self.scheduler, LMSDiscreteScheduler):\n                sigma = self.scheduler.sigmas[i]\n                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)[\"sample\"]\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # 
compute the previous noisy sample x_t -> x_t-1\n            if isinstance(self.scheduler, LMSDiscreteScheduler):\n                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)[\"prev_sample\"]\n            else:\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[\"prev_sample\"]\n            \n            if progress:\n                latent_image = self.vae.decode(1 / 0.18215 * latents)\n                latent_image = (latent_image / 2 + 0.5).clamp(0, 1)\n                latent_image = latent_image.cpu().permute(0, 2, 3, 1).numpy()\n\n                if latent_image.ndim == 3:\n                    latent_image = latent_image[None, ...]\n                latent_image = (latent_image * 255).round().astype('uint8')\n                latent_image = [Image.fromarray(image) for image in latent_image]\n                images.append(latent_image[0])\n\n\n        if progress:\n            images[0].save(f'output.gif', save_all=True, append_images=images[1:], optimize=False, loop=0, duration=125)\n\n        # scale and decode the image latents with vae\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents)\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        return {\"sample\": image}\n"
  },
  {
    "path": "src/stablediffusion/inpaint.py",
    "content": "import inspect\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\n\nimport PIL\nfrom diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\n\ndef preprocess(image):\n    w, h = image.size\n    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32\n    image = image.resize((w, h), resample=PIL.Image.LANCZOS)\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image)\n    return 2.0 * image - 1.0\n\ndef preprocess_mask(mask):\n    mask=mask.convert(\"L\")\n    w, h = mask.size\n    mask = mask.resize((int(w / 8), int(h / 8)), resample=PIL.Image.LANCZOS)\n    mask = np.array(mask).astype(np.float32) / 255.0\n    mask = np.tile(mask,(4,1,1))\n    mask = mask[None].transpose(0, 1, 2, 3)#what does this step do?\n    mask = torch.from_numpy(mask).bool()\n    return (mask).long()\n\nclass StableDiffusionInpaintingPipeline(DiffusionPipeline):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler],\n    ):\n        super().__init__()\n        scheduler = scheduler.set_format(\"pt\")\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        init_image: torch.FloatTensor,\n        mask_image: torch.FloatTensor,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        eta: Optional[float] = 0.0,\n        
generator: Optional[torch.Generator] = None,\n        output_type: Optional[str] = \"pil\",\n    ):\n\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should in [0.0, 1.0] but is {strength}\")\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        offset = 0\n        if accepts_offset:\n            offset = 1\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n\n        # encode the init image into latents and scale the latents\n        init_latents = self.vae.encode(init_image.to(self.device)).sample()\n        init_latents = 0.18215 * init_latents\n        init_latents_orig = init_latents\n\n        # prepare init_latents noise to latents\n        init_latents = torch.cat([init_latents] * batch_size)\n\n        # preprocess mask\n        mask = preprocess_mask(mask_image).to(self.device)\n        mask = torch.cat([mask] * batch_size)\n\n        # get the original timestep using init_timestep\n        init_timestep = int(num_inference_steps * strength) + offset\n        init_timestep = min(init_timestep, num_inference_steps)\n        timesteps = self.scheduler.timesteps[-init_timestep]\n        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)\n\n        # add noise to latents using the timesteps\n        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)\n        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)\n\n        # get prompt text embeddings\n        text_input = 
self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        latents = init_latents\n        t_start = max(num_inference_steps - init_timestep + offset, 0)\n        for i, t in 
tqdm(enumerate(self.scheduler.timesteps[t_start:])):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)[\"sample\"]\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[\"prev_sample\"]\n\n            #masking\n            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)\n            latents = ( init_latents_proper * mask ) + ( latents * (1-mask) )\n\n        # scale and decode the image latents with vae\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents)\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        return {\"sample\": image, \"nsfw_content_detected\": False}"
  },
  {
    "path": "src/stablediffusion/ldm/__init__.py",
    "content": "from .generate import Generate"
  },
  {
    "path": "src/stablediffusion/ldm/data/__init__.py",
    "content": ""
  },
  {
    "path": "src/stablediffusion/ldm/data/base.py",
    "content": "from abc import abstractmethod\nfrom torch.utils.data import (\n    Dataset,\n    ConcatDataset,\n    ChainDataset,\n    IterableDataset,\n)\n\n\nclass Txt2ImgIterableBaseDataset(IterableDataset):\n    \"\"\"\n    Define an interface to make the IterableDatasets for text2img data chainable\n    \"\"\"\n\n    def __init__(self, num_records=0, valid_ids=None, size=256):\n        super().__init__()\n        self.num_records = num_records\n        self.valid_ids = valid_ids\n        self.sample_ids = valid_ids\n        self.size = size\n\n        print(\n            f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'\n        )\n\n    def __len__(self):\n        return self.num_records\n\n    @abstractmethod\n    def __iter__(self):\n        pass\n"
  },
  {
    "path": "src/stablediffusion/ldm/data/imagenet.py",
    "content": "import os, yaml, pickle, shutil, tarfile, glob\nimport cv2\nimport albumentations\nimport PIL\nimport numpy as np\nimport torchvision.transforms.functional as TF\nfrom omegaconf import OmegaConf\nfrom functools import partial\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset, Subset\n\nimport taming.data.utils as tdu\nfrom taming.data.imagenet import (\n    str_to_indices,\n    give_synsets_from_indices,\n    download,\n    retrieve,\n)\nfrom taming.data.imagenet import ImagePaths\n\nfrom ldm.modules.image_degradation import (\n    degradation_fn_bsr,\n    degradation_fn_bsr_light,\n)\n\n\ndef synset2idx(path_to_yaml='data/index_synset.yaml'):\n    with open(path_to_yaml) as f:\n        di2s = yaml.load(f)\n    return dict((v, k) for k, v in di2s.items())\n\n\nclass ImageNetBase(Dataset):\n    def __init__(self, config=None):\n        self.config = config or OmegaConf.create()\n        if not type(self.config) == dict:\n            self.config = OmegaConf.to_container(self.config)\n        self.keep_orig_class_label = self.config.get(\n            'keep_orig_class_label', False\n        )\n        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths\n        self._prepare()\n        self._prepare_synset_to_human()\n        self._prepare_idx_to_synset()\n        self._prepare_human_to_integer_label()\n        self._load()\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, i):\n        return self.data[i]\n\n    def _prepare(self):\n        raise NotImplementedError()\n\n    def _filter_relpaths(self, relpaths):\n        ignore = set(\n            [\n                'n06596364_9591.JPEG',\n            ]\n        )\n        relpaths = [\n            rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore\n        ]\n        if 'sub_indices' in self.config:\n            indices = 
str_to_indices(self.config['sub_indices'])\n            synsets = give_synsets_from_indices(\n                indices, path_to_yaml=self.idx2syn\n            )  # returns a list of strings\n            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)\n            files = []\n            for rpath in relpaths:\n                syn = rpath.split('/')[0]\n                if syn in synsets:\n                    files.append(rpath)\n            return files\n        else:\n            return relpaths\n\n    def _prepare_synset_to_human(self):\n        SIZE = 2655750\n        URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'\n        self.human_dict = os.path.join(self.root, 'synset_human.txt')\n        if (\n            not os.path.exists(self.human_dict)\n            or not os.path.getsize(self.human_dict) == SIZE\n        ):\n            download(URL, self.human_dict)\n\n    def _prepare_idx_to_synset(self):\n        URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'\n        self.idx2syn = os.path.join(self.root, 'index_synset.yaml')\n        if not os.path.exists(self.idx2syn):\n            download(URL, self.idx2syn)\n\n    def _prepare_human_to_integer_label(self):\n        URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'\n        self.human2integer = os.path.join(\n            self.root, 'imagenet1000_clsidx_to_labels.txt'\n        )\n        if not os.path.exists(self.human2integer):\n            download(URL, self.human2integer)\n        with open(self.human2integer, 'r') as f:\n            lines = f.read().splitlines()\n            assert len(lines) == 1000\n            self.human2integer_dict = dict()\n            for line in lines:\n                value, key = line.split(':')\n                self.human2integer_dict[key] = int(value)\n\n    def _load(self):\n        with open(self.txt_filelist, 'r') as f:\n            self.relpaths = f.read().splitlines()\n            l1 = 
len(self.relpaths)\n            self.relpaths = self._filter_relpaths(self.relpaths)\n            print(\n                'Removed {} files from filelist during filtering.'.format(\n                    l1 - len(self.relpaths)\n                )\n            )\n\n        self.synsets = [p.split('/')[0] for p in self.relpaths]\n        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]\n\n        unique_synsets = np.unique(self.synsets)\n        class_dict = dict(\n            (synset, i) for i, synset in enumerate(unique_synsets)\n        )\n        if not self.keep_orig_class_label:\n            self.class_labels = [class_dict[s] for s in self.synsets]\n        else:\n            self.class_labels = [self.synset2idx[s] for s in self.synsets]\n\n        with open(self.human_dict, 'r') as f:\n            human_dict = f.read().splitlines()\n            human_dict = dict(line.split(maxsplit=1) for line in human_dict)\n\n        self.human_labels = [human_dict[s] for s in self.synsets]\n\n        labels = {\n            'relpath': np.array(self.relpaths),\n            'synsets': np.array(self.synsets),\n            'class_label': np.array(self.class_labels),\n            'human_label': np.array(self.human_labels),\n        }\n\n        if self.process_images:\n            self.size = retrieve(self.config, 'size', default=256)\n            self.data = ImagePaths(\n                self.abspaths,\n                labels=labels,\n                size=self.size,\n                random_crop=self.random_crop,\n            )\n        else:\n            self.data = self.abspaths\n\n\nclass ImageNetTrain(ImageNetBase):\n    NAME = 'ILSVRC2012_train'\n    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'\n    AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'\n    FILES = [\n        'ILSVRC2012_img_train.tar',\n    ]\n    SIZES = [\n        147897477120,\n    ]\n\n    def __init__(self, process_images=True, data_root=None, **kwargs):\n        
self.process_images = process_images\n        self.data_root = data_root\n        super().__init__(**kwargs)\n\n    def _prepare(self):\n        if self.data_root:\n            self.root = os.path.join(self.data_root, self.NAME)\n        else:\n            cachedir = os.environ.get(\n                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')\n            )\n            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)\n\n        self.datadir = os.path.join(self.root, 'data')\n        self.txt_filelist = os.path.join(self.root, 'filelist.txt')\n        self.expected_length = 1281167\n        self.random_crop = retrieve(\n            self.config, 'ImageNetTrain/random_crop', default=True\n        )\n        if not tdu.is_prepared(self.root):\n            # prep\n            print('Preparing dataset {} in {}'.format(self.NAME, self.root))\n\n            datadir = self.datadir\n            if not os.path.exists(datadir):\n                path = os.path.join(self.root, self.FILES[0])\n                if (\n                    not os.path.exists(path)\n                    or not os.path.getsize(path) == self.SIZES[0]\n                ):\n                    import academictorrents as at\n\n                    atpath = at.get(self.AT_HASH, datastore=self.root)\n                    assert atpath == path\n\n                print('Extracting {} to {}'.format(path, datadir))\n                os.makedirs(datadir, exist_ok=True)\n                with tarfile.open(path, 'r:') as tar:\n                    tar.extractall(path=datadir)\n\n                print('Extracting sub-tars.')\n                subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))\n                for subpath in tqdm(subpaths):\n                    subdir = subpath[: -len('.tar')]\n                    os.makedirs(subdir, exist_ok=True)\n                    with tarfile.open(subpath, 'r:') as tar:\n                        tar.extractall(path=subdir)\n\n            filelist = 
glob.glob(os.path.join(datadir, '**', '*.JPEG'))\n            filelist = [os.path.relpath(p, start=datadir) for p in filelist]\n            filelist = sorted(filelist)\n            filelist = '\\n'.join(filelist) + '\\n'\n            with open(self.txt_filelist, 'w') as f:\n                f.write(filelist)\n\n            tdu.mark_prepared(self.root)\n\n\nclass ImageNetValidation(ImageNetBase):\n    NAME = 'ILSVRC2012_validation'\n    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'\n    AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'\n    VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'\n    FILES = [\n        'ILSVRC2012_img_val.tar',\n        'validation_synset.txt',\n    ]\n    SIZES = [\n        6744924160,\n        1950000,\n    ]\n\n    def __init__(self, process_images=True, data_root=None, **kwargs):\n        self.data_root = data_root\n        self.process_images = process_images\n        super().__init__(**kwargs)\n\n    def _prepare(self):\n        if self.data_root:\n            self.root = os.path.join(self.data_root, self.NAME)\n        else:\n            cachedir = os.environ.get(\n                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')\n            )\n            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)\n        self.datadir = os.path.join(self.root, 'data')\n        self.txt_filelist = os.path.join(self.root, 'filelist.txt')\n        self.expected_length = 50000\n        self.random_crop = retrieve(\n            self.config, 'ImageNetValidation/random_crop', default=False\n        )\n        if not tdu.is_prepared(self.root):\n            # prep\n            print('Preparing dataset {} in {}'.format(self.NAME, self.root))\n\n            datadir = self.datadir\n            if not os.path.exists(datadir):\n                path = os.path.join(self.root, self.FILES[0])\n                if (\n                    not os.path.exists(path)\n                    or not 
os.path.getsize(path) == self.SIZES[0]\n                ):\n                    import academictorrents as at\n\n                    atpath = at.get(self.AT_HASH, datastore=self.root)\n                    assert atpath == path\n\n                print('Extracting {} to {}'.format(path, datadir))\n                os.makedirs(datadir, exist_ok=True)\n                with tarfile.open(path, 'r:') as tar:\n                    tar.extractall(path=datadir)\n\n                vspath = os.path.join(self.root, self.FILES[1])\n                if (\n                    not os.path.exists(vspath)\n                    or not os.path.getsize(vspath) == self.SIZES[1]\n                ):\n                    download(self.VS_URL, vspath)\n\n                with open(vspath, 'r') as f:\n                    synset_dict = f.read().splitlines()\n                    synset_dict = dict(line.split() for line in synset_dict)\n\n                print('Reorganizing into synset folders')\n                synsets = np.unique(list(synset_dict.values()))\n                for s in synsets:\n                    os.makedirs(os.path.join(datadir, s), exist_ok=True)\n                for k, v in synset_dict.items():\n                    src = os.path.join(datadir, k)\n                    dst = os.path.join(datadir, v)\n                    shutil.move(src, dst)\n\n            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))\n            filelist = [os.path.relpath(p, start=datadir) for p in filelist]\n            filelist = sorted(filelist)\n            filelist = '\\n'.join(filelist) + '\\n'\n            with open(self.txt_filelist, 'w') as f:\n                f.write(filelist)\n\n            tdu.mark_prepared(self.root)\n\n\nclass ImageNetSR(Dataset):\n    def __init__(\n        self,\n        size=None,\n        degradation=None,\n        downscale_f=4,\n        min_crop_f=0.5,\n        max_crop_f=1.0,\n        random_crop=True,\n    ):\n        \"\"\"\n        Imagenet Superresolution 
Dataloader\n        Performs following ops in order:\n        1.  takes a crop of size s from image either as random or center crop\n        2.  resizes crop to size with cv2.INTER_AREA interpolation\n        3.  degrades resized crop with degradation_fn\n\n        :param size: resizing to size after cropping\n        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light\n        :param downscale_f: Low Resolution Downsample factor\n        :param min_crop_f: determines crop size s,\n          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)\n        :param max_crop_f: upper bound of the interval c is sampled from (must be <= 1.0)\n        :param random_crop: if True take a random crop, otherwise a center crop\n        \"\"\"\n        self.base = self.get_base()\n        assert size\n        assert (size / downscale_f).is_integer()\n        self.size = size\n        self.LR_size = int(size / downscale_f)\n        self.min_crop_f = min_crop_f\n        self.max_crop_f = max_crop_f\n        assert max_crop_f <= 1.0\n        self.center_crop = not random_crop\n\n        self.image_rescaler = albumentations.SmallestMaxSize(\n            max_size=size, interpolation=cv2.INTER_AREA\n        )\n\n        self.pil_interpolation = (\n            False  # set to True below if the degradation name starts with 'pil_'\n        )\n\n        if degradation == 'bsrgan':\n            self.degradation_process = partial(\n                degradation_fn_bsr, sf=downscale_f\n            )\n\n        elif degradation == 'bsrgan_light':\n            self.degradation_process = partial(\n                degradation_fn_bsr_light, sf=downscale_f\n            )\n\n        else:\n            interpolation_fn = {\n                'cv_nearest': cv2.INTER_NEAREST,\n                'cv_bilinear': cv2.INTER_LINEAR,\n                'cv_bicubic': cv2.INTER_CUBIC,\n                'cv_area': cv2.INTER_AREA,\n                'cv_lanczos': cv2.INTER_LANCZOS4,\n                'pil_nearest': PIL.Image.NEAREST,\n                'pil_bilinear': 
PIL.Image.BILINEAR,\n                'pil_bicubic': PIL.Image.BICUBIC,\n                'pil_box': PIL.Image.BOX,\n                'pil_hamming': PIL.Image.HAMMING,\n                'pil_lanczos': PIL.Image.LANCZOS,\n            }[degradation]\n\n            self.pil_interpolation = degradation.startswith('pil_')\n\n            if self.pil_interpolation:\n                self.degradation_process = partial(\n                    TF.resize,\n                    size=self.LR_size,\n                    interpolation=interpolation_fn,\n                )\n\n            else:\n                self.degradation_process = albumentations.SmallestMaxSize(\n                    max_size=self.LR_size, interpolation=interpolation_fn\n                )\n\n    def __len__(self):\n        return len(self.base)\n\n    def __getitem__(self, i):\n        example = self.base[i]\n        image = Image.open(example['file_path_'])\n\n        if not image.mode == 'RGB':\n            image = image.convert('RGB')\n\n        image = np.array(image).astype(np.uint8)\n\n        min_side_len = min(image.shape[:2])\n        crop_side_len = min_side_len * np.random.uniform(\n            self.min_crop_f, self.max_crop_f, size=None\n        )\n        crop_side_len = int(crop_side_len)\n\n        if self.center_crop:\n            self.cropper = albumentations.CenterCrop(\n                height=crop_side_len, width=crop_side_len\n            )\n\n        else:\n            self.cropper = albumentations.RandomCrop(\n                height=crop_side_len, width=crop_side_len\n            )\n\n        image = self.cropper(image=image)['image']\n        image = self.image_rescaler(image=image)['image']\n\n        if self.pil_interpolation:\n            image_pil = PIL.Image.fromarray(image)\n            LR_image = self.degradation_process(image_pil)\n            LR_image = np.array(LR_image).astype(np.uint8)\n\n        else:\n            LR_image = self.degradation_process(image=image)['image']\n\n        
example['image'] = (image / 127.5 - 1.0).astype(np.float32)\n        example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)\n\n        return example\n\n\nclass ImageNetSRTrain(ImageNetSR):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    def get_base(self):\n        with open('data/imagenet_train_hr_indices.p', 'rb') as f:\n            indices = pickle.load(f)\n        dset = ImageNetTrain(\n            process_images=False,\n        )\n        return Subset(dset, indices)\n\n\nclass ImageNetSRValidation(ImageNetSR):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    def get_base(self):\n        with open('data/imagenet_val_hr_indices.p', 'rb') as f:\n            indices = pickle.load(f)\n        dset = ImageNetValidation(\n            process_images=False,\n        )\n        return Subset(dset, indices)\n"
  },
  {
    "path": "src/stablediffusion/ldm/data/lsun.py",
    "content": "import os\nimport numpy as np\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\n\n\nclass LSUNBase(Dataset):\n    def __init__(\n        self,\n        txt_file,\n        data_root,\n        size=None,\n        interpolation='bicubic',\n        flip_p=0.5,\n    ):\n        self.data_paths = txt_file\n        self.data_root = data_root\n        with open(self.data_paths, 'r') as f:\n            self.image_paths = f.read().splitlines()\n        self._length = len(self.image_paths)\n        self.labels = {\n            'relative_file_path_': [l for l in self.image_paths],\n            'file_path_': [\n                os.path.join(self.data_root, l) for l in self.image_paths\n            ],\n        }\n\n        self.size = size\n        self.interpolation = {\n            'linear': PIL.Image.LINEAR,\n            'bilinear': PIL.Image.BILINEAR,\n            'bicubic': PIL.Image.BICUBIC,\n            'lanczos': PIL.Image.LANCZOS,\n        }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = dict((k, self.labels[k][i]) for k in self.labels)\n        image = Image.open(example['file_path_'])\n        if not image.mode == 'RGB':\n            image = image.convert('RGB')\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n        crop = min(img.shape[0], img.shape[1])\n        h, w, = (\n            img.shape[0],\n            img.shape[1],\n        )\n        img = img[\n            (h - crop) // 2 : (h + crop) // 2,\n            (w - crop) // 2 : (w + crop) // 2,\n        ]\n\n        image = Image.fromarray(img)\n        if self.size is not None:\n            image = image.resize(\n                (self.size, self.size), resample=self.interpolation\n            )\n\n        image = self.flip(image)\n        image = 
np.array(image).astype(np.uint8)\n        example['image'] = (image / 127.5 - 1.0).astype(np.float32)\n        return example\n\n\nclass LSUNChurchesTrain(LSUNBase):\n    def __init__(self, **kwargs):\n        super().__init__(\n            txt_file='data/lsun/church_outdoor_train.txt',\n            data_root='data/lsun/churches',\n            **kwargs\n        )\n\n\nclass LSUNChurchesValidation(LSUNBase):\n    def __init__(self, flip_p=0.0, **kwargs):\n        super().__init__(\n            txt_file='data/lsun/church_outdoor_val.txt',\n            data_root='data/lsun/churches',\n            flip_p=flip_p,\n            **kwargs\n        )\n\n\nclass LSUNBedroomsTrain(LSUNBase):\n    def __init__(self, **kwargs):\n        super().__init__(\n            txt_file='data/lsun/bedrooms_train.txt',\n            data_root='data/lsun/bedrooms',\n            **kwargs\n        )\n\n\nclass LSUNBedroomsValidation(LSUNBase):\n    def __init__(self, flip_p=0.0, **kwargs):\n        super().__init__(\n            txt_file='data/lsun/bedrooms_val.txt',\n            data_root='data/lsun/bedrooms',\n            flip_p=flip_p,\n            **kwargs\n        )\n\n\nclass LSUNCatsTrain(LSUNBase):\n    def __init__(self, **kwargs):\n        super().__init__(\n            txt_file='data/lsun/cat_train.txt',\n            data_root='data/lsun/cats',\n            **kwargs\n        )\n\n\nclass LSUNCatsValidation(LSUNBase):\n    def __init__(self, flip_p=0.0, **kwargs):\n        super().__init__(\n            txt_file='data/lsun/cat_val.txt',\n            data_root='data/lsun/cats',\n            flip_p=flip_p,\n            **kwargs\n        )\n"
  },
  {
    "path": "src/stablediffusion/ldm/data/personalized.py",
    "content": "import os\nimport numpy as np\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\n\nimport random\n\nimagenet_templates_smallest = [\n    'a photo of a {}',\n]\n\nimagenet_templates_small = [\n    'a photo of a {}',\n    'a rendering of a {}',\n    'a cropped photo of the {}',\n    'the photo of a {}',\n    'a photo of a clean {}',\n    'a photo of a dirty {}',\n    'a dark photo of the {}',\n    'a photo of my {}',\n    'a photo of the cool {}',\n    'a close-up photo of a {}',\n    'a bright photo of the {}',\n    'a cropped photo of a {}',\n    'a photo of the {}',\n    'a good photo of the {}',\n    'a photo of one {}',\n    'a close-up photo of the {}',\n    'a rendition of the {}',\n    'a photo of the clean {}',\n    'a rendition of a {}',\n    'a photo of a nice {}',\n    'a good photo of a {}',\n    'a photo of the nice {}',\n    'a photo of the small {}',\n    'a photo of the weird {}',\n    'a photo of the large {}',\n    'a photo of a cool {}',\n    'a photo of a small {}',\n]\n\nimagenet_dual_templates_small = [\n    'a photo of a {} with {}',\n    'a rendering of a {} with {}',\n    'a cropped photo of the {} with {}',\n    'the photo of a {} with {}',\n    'a photo of a clean {} with {}',\n    'a photo of a dirty {} with {}',\n    'a dark photo of the {} with {}',\n    'a photo of my {} with {}',\n    'a photo of the cool {} with {}',\n    'a close-up photo of a {} with {}',\n    'a bright photo of the {} with {}',\n    'a cropped photo of a {} with {}',\n    'a photo of the {} with {}',\n    'a good photo of the {} with {}',\n    'a photo of one {} with {}',\n    'a close-up photo of the {} with {}',\n    'a rendition of the {} with {}',\n    'a photo of the clean {} with {}',\n    'a rendition of a {} with {}',\n    'a photo of a nice {} with {}',\n    'a good photo of a {} with {}',\n    'a photo of the nice {} with {}',\n    'a photo of the small {} with {}',\n    'a photo 
of the weird {} with {}',\n    'a photo of the large {} with {}',\n    'a photo of a cool {} with {}',\n    'a photo of a small {} with {}',\n]\n\nper_img_token_list = [\n    'א',\n    'ב',\n    'ג',\n    'ד',\n    'ה',\n    'ו',\n    'ז',\n    'ח',\n    'ט',\n    'י',\n    'כ',\n    'ל',\n    'מ',\n    'נ',\n    'ס',\n    'ע',\n    'פ',\n    'צ',\n    'ק',\n    'ר',\n    'ש',\n    'ת',\n]\n\n\nclass PersonalizedBase(Dataset):\n    def __init__(\n        self,\n        data_root,\n        size=None,\n        repeats=100,\n        interpolation='bicubic',\n        flip_p=0.5,\n        set='train',\n        placeholder_token='*',\n        per_image_tokens=False,\n        center_crop=False,\n        mixing_prob=0.25,\n        coarse_class_text=None,\n    ):\n\n        self.data_root = data_root\n\n        self.image_paths = [\n            os.path.join(self.data_root, file_path)\n            for file_path in os.listdir(self.data_root)\n        ]\n\n        # self._length = len(self.image_paths)\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.center_crop = center_crop\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n\n        if per_image_tokens:\n            assert self.num_images < len(\n                per_img_token_list\n            ), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. 
To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == 'train':\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.interpolation = {\n            'linear': PIL.Image.LINEAR,\n            'bilinear': PIL.Image.BILINEAR,\n            'bicubic': PIL.Image.BICUBIC,\n            'lanczos': PIL.Image.LANCZOS,\n        }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == 'RGB':\n            image = image.convert('RGB')\n\n        placeholder_string = self.placeholder_token\n        if self.coarse_class_text:\n            placeholder_string = (\n                f'{self.coarse_class_text} {placeholder_string}'\n            )\n\n        if self.per_image_tokens and np.random.uniform() < self.mixing_prob:\n            text = random.choice(imagenet_dual_templates_small).format(\n                placeholder_string, per_img_token_list[i % self.num_images]\n            )\n        else:\n            text = random.choice(imagenet_templates_small).format(\n                placeholder_string\n            )\n\n        example['caption'] = text\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            h, w, = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[\n                (h - crop) // 2 : (h + crop) // 2,\n                (w - crop) // 2 : (w + crop) // 2,\n            ]\n\n        image = Image.fromarray(img)\n        if self.size is not None:\n            image = image.resize(\n                (self.size, self.size), resample=self.interpolation\n            )\n\n        image = 
self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        example['image'] = (image / 127.5 - 1.0).astype(np.float32)\n        return example\n"
  },
  {
    "path": "src/stablediffusion/ldm/data/personalized_style.py",
    "content": "import os\nimport numpy as np\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\n\nimport random\n\nimagenet_templates_small = [\n    'a painting in the style of {}',\n    'a rendering in the style of {}',\n    'a cropped painting in the style of {}',\n    'the painting in the style of {}',\n    'a clean painting in the style of {}',\n    'a dirty painting in the style of {}',\n    'a dark painting in the style of {}',\n    'a picture in the style of {}',\n    'a cool painting in the style of {}',\n    'a close-up painting in the style of {}',\n    'a bright painting in the style of {}',\n    'a cropped painting in the style of {}',\n    'a good painting in the style of {}',\n    'a close-up painting in the style of {}',\n    'a rendition in the style of {}',\n    'a nice painting in the style of {}',\n    'a small painting in the style of {}',\n    'a weird painting in the style of {}',\n    'a large painting in the style of {}',\n]\n\nimagenet_dual_templates_small = [\n    'a painting in the style of {} with {}',\n    'a rendering in the style of {} with {}',\n    'a cropped painting in the style of {} with {}',\n    'the painting in the style of {} with {}',\n    'a clean painting in the style of {} with {}',\n    'a dirty painting in the style of {} with {}',\n    'a dark painting in the style of {} with {}',\n    'a cool painting in the style of {} with {}',\n    'a close-up painting in the style of {} with {}',\n    'a bright painting in the style of {} with {}',\n    'a cropped painting in the style of {} with {}',\n    'a good painting in the style of {} with {}',\n    'a painting of one {} in the style of {}',\n    'a nice painting in the style of {} with {}',\n    'a small painting in the style of {} with {}',\n    'a weird painting in the style of {} with {}',\n    'a large painting in the style of {} with {}',\n]\n\nper_img_token_list = [\n    'א',\n    'ב',\n    'ג',\n    'ד',\n 
   'ה',\n    'ו',\n    'ז',\n    'ח',\n    'ט',\n    'י',\n    'כ',\n    'ל',\n    'מ',\n    'נ',\n    'ס',\n    'ע',\n    'פ',\n    'צ',\n    'ק',\n    'ר',\n    'ש',\n    'ת',\n]\n\n\nclass PersonalizedBase(Dataset):\n    def __init__(\n        self,\n        data_root,\n        size=None,\n        repeats=100,\n        interpolation='bicubic',\n        flip_p=0.5,\n        set='train',\n        placeholder_token='*',\n        per_image_tokens=False,\n        center_crop=False,\n    ):\n\n        self.data_root = data_root\n\n        self.image_paths = [\n            os.path.join(self.data_root, file_path)\n            for file_path in os.listdir(self.data_root)\n        ]\n\n        # self._length = len(self.image_paths)\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.center_crop = center_crop\n\n        if per_image_tokens:\n            assert self.num_images < len(\n                per_img_token_list\n            ), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. 
To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == 'train':\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.interpolation = {\n            'linear': PIL.Image.LINEAR,\n            'bilinear': PIL.Image.BILINEAR,\n            'bicubic': PIL.Image.BICUBIC,\n            'lanczos': PIL.Image.LANCZOS,\n        }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == 'RGB':\n            image = image.convert('RGB')\n\n        if self.per_image_tokens and np.random.uniform() < 0.25:\n            text = random.choice(imagenet_dual_templates_small).format(\n                self.placeholder_token, per_img_token_list[i % self.num_images]\n            )\n        else:\n            text = random.choice(imagenet_templates_small).format(\n                self.placeholder_token\n            )\n\n        example['caption'] = text\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            h, w, = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[\n                (h - crop) // 2 : (h + crop) // 2,\n                (w - crop) // 2 : (w + crop) // 2,\n            ]\n\n        image = Image.fromarray(img)\n        if self.size is not None:\n            image = image.resize(\n                (self.size, self.size), resample=self.interpolation\n            )\n\n        image = self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        example['image'] = (image / 127.5 - 1.0).astype(np.float32)\n        return example\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/conditioning.py",
    "content": "'''\nThis module handles the generation of the conditioning tensors, including management of\nweighted subprompts.\n\nUseful function exports:\n\nget_uc_and_c()                  get the conditioned and unconditioned latent\nsplit_weighted_subpromopts()    split subprompts, normalize and weight them\nlog_tokenization()              print out colour-coded tokens and warn if truncated\n\n'''\nimport re\nimport torch\n\ndef get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):\n    uc = model.get_learned_conditioning([''])\n\n    # get weighted sub-prompts\n    weighted_subprompts = split_weighted_subprompts(\n        prompt, skip_normalize\n    )\n\n    if len(weighted_subprompts) > 1:\n        # i dont know if this is correct.. but it works\n        c = torch.zeros_like(uc)\n        # normalize each \"sub prompt\" and add it\n        for subprompt, weight in weighted_subprompts:\n            log_tokenization(subprompt, model, log_tokens)\n            c = torch.add(\n                c,\n                model.get_learned_conditioning([subprompt]),\n                alpha=weight,\n            )\n    else:   # just standard 1 prompt\n        log_tokenization(prompt, model, log_tokens)\n        c = model.get_learned_conditioning([prompt])\n    return (uc, c)\n\ndef split_weighted_subprompts(text, skip_normalize=False)->list:\n    \"\"\"\n    grabs all text up to the first occurrence of ':'\n    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight\n    if ':' has no value defined, defaults to 1.0\n    repeats until no text remaining\n    \"\"\"\n    prompt_parser = re.compile(\"\"\"\n            (?P<prompt>     # capture group for 'prompt'\n            (?:\\\\\\:|[^:])+  # match one or more non ':' characters or escaped colons '\\:'\n            )               # end 'prompt'\n            (?:             # non-capture group\n            :+              # match one or more ':' characters\n            (?P<weight>    
 # capture group for 'weight'\n            -?\\d+(?:\\.\\d+)? # match positive or negative integer or decimal number\n            )?              # end weight capture group, make optional\n            \\s*             # strip spaces after weight\n            |               # OR\n            $               # else, if no ':' then match end of line\n            )               # end non-capture group\n            \"\"\", re.VERBOSE)\n    parsed_prompts = [(match.group(\"prompt\").replace(\"\\\\:\", \":\"), float(\n        match.group(\"weight\") or 1)) for match in re.finditer(prompt_parser, text)]\n    if skip_normalize:\n        return parsed_prompts\n    weight_sum = sum(map(lambda x: x[1], parsed_prompts))\n    if weight_sum == 0:\n        print(\n            \"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.\")\n        equal_weight = 1 / len(parsed_prompts)\n        return [(x[0], equal_weight) for x in parsed_prompts]\n    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]\n        \n# shows how the prompt is tokenized\n# usually tokens have '</w>' to indicate end-of-word,\n# but for readability it has been replaced with ' '\ndef log_tokenization(text, model, log=False):\n    if not log:\n        return\n    tokens    = model.cond_stage_model.tokenizer._tokenize(text)\n    tokenized = \"\"\n    discarded = \"\"\n    usedTokens = 0\n    totalTokens = len(tokens)\n    for i in range(0, totalTokens):\n        token = tokens[i].replace('</w>', ' ')\n        # alternate color\n        s = (usedTokens % 6) + 1\n        if i < model.cond_stage_model.max_length:\n            tokenized = tokenized + f\"\\x1b[0;3{s};40m{token}\"\n            usedTokens += 1\n        else:  # over max token length\n            discarded = discarded + f\"\\x1b[0;3{s};40m{token}\"\n        print(f\"\\n>> Tokens ({usedTokens}):\\n{tokenized}\\x1b[0m\")\n        if discarded != \"\":\n            print(\n                f\">> Tokens Discarded 
({totalTokens-usedTokens}):\\n{discarded}\\x1b[0m\"\n            )\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/devices.py",
    "content": "import torch\nfrom torch import autocast\nfrom contextlib import contextmanager, nullcontext\n\ndef choose_torch_device() -> str:\n    '''Convenience routine for guessing which GPU device to run model on'''\n    if torch.cuda.is_available():\n        return 'cuda'\n    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n        return 'mps'\n    return 'cpu'\n\ndef choose_autocast_device(device):\n    '''Returns an autocast compatible device from a torch device'''\n    device_type = device.type # this returns 'mps' on M1\n    # autocast only supports cuda or cpu\n    if device_type in ('cuda','cpu'):\n        return device_type,autocast\n    else:\n        return 'cpu',nullcontext\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/generator/__init__.py",
    "content": "'''\nInitialization file for the ldm.dream.generator package\n'''\nfrom .base import Generator\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/generator/base.py",
    "content": "'''\nBase class for ldm.dream.generator.*\nincluding img2img, txt2img, and inpaint\n'''\nimport torch\nimport numpy as  np\nimport random\nfrom tqdm import tqdm, trange\nfrom PIL               import Image\nfrom einops import rearrange, repeat\nfrom pytorch_lightning import seed_everything\nfrom src.stablediffusion.ldm.dream.devices import choose_autocast_device\n\ndownsampling = 8\n\nclass Generator():\n    def __init__(self,model):\n        self.model               = model\n        self.seed                = None\n        self.latent_channels     = model.channels\n        self.downsampling_factor = downsampling   # BUG: should come from model or config\n        self.variation_amount    = 0\n        self.with_variations     = []\n\n    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py\n    def get_make_image(self,prompt,**kwargs):\n        \"\"\"\n        Returns a function returning an image derived from the prompt and the initial image\n        Return value depends on the seed at the time you call it\n        \"\"\"\n        raise NotImplementedError(\"image_iterator() must be implemented in a descendent class\")\n\n    def set_variation(self, seed, variation_amount, with_variations):\n        self.seed             = seed\n        self.variation_amount = variation_amount\n        self.with_variations  = with_variations\n\n    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,\n                 image_callback=None, step_callback=None,\n                 **kwargs):\n        device_type,scope   = choose_autocast_device(self.model.device)\n        make_image          = self.get_make_image(\n            prompt,\n            init_image    = init_image,\n            width         = width,\n            height        = height,\n            step_callback = step_callback,\n            **kwargs\n        )\n\n        results             = []\n        seed                = seed if seed else self.new_seed()\n        
seed, initial_noise = self.generate_initial_noise(seed, width, height)\n        with scope(device_type), self.model.ema_scope():\n            for n in trange(iterations, desc='Generating'):\n                x_T = None\n                if self.variation_amount > 0:\n                    seed_everything(seed)\n                    target_noise = self.get_noise(width,height)\n                    x_T = self.slerp(self.variation_amount, initial_noise, target_noise)\n                elif initial_noise is not None:\n                    # i.e. we specified particular variations\n                    x_T = initial_noise\n                else:\n                    seed_everything(seed)\n                    if self.model.device.type == 'mps':\n                        x_T = self.get_noise(width,height)\n\n                # make_image will do the equivalent of get_noise itself\n                image = make_image(x_T)\n                results.append([image, seed])\n                if image_callback is not None:\n                    image_callback(image, seed)\n                seed = self.new_seed()\n        return results\n    \n    def sample_to_image(self,samples):\n        \"\"\"\n        Decodes samples with self.model.decode_first_stage() and returns the result as a PIL Image.\n        Pixel values are rescaled via (x + 1) / 2, clamped to [0, 1] and multiplied by 255;\n        raises an Exception if the decoded batch does not contain exactly one image.\n        \"\"\"\n        x_samples = self.model.decode_first_stage(samples)\n        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n        if len(x_samples) != 1:\n            raise Exception(\n                f'>> expected to get a single image, but got {len(x_samples)}')\n        x_sample = 255.0 * rearrange(\n            x_samples[0].cpu().numpy(), 'c h w -> h w c'\n        )\n        return Image.fromarray(x_sample.astype(np.uint8))\n\n    def generate_initial_noise(self, seed, width, height):\n        initial_noise = None\n        if self.variation_amount > 0 or len(self.with_variations) > 0:\n            # use fixed 
initial noise plus random noise per iteration\n            seed_everything(seed)\n            initial_noise = self.get_noise(width,height)\n            for v_seed, v_weight in self.with_variations:\n                seed = v_seed\n                seed_everything(seed)\n                next_noise = self.get_noise(width,height)\n                initial_noise = self.slerp(v_weight, initial_noise, next_noise)\n            if self.variation_amount > 0:\n                random.seed() # reset RNG to an actually random state, so we can get a random seed for variations\n                seed = random.randrange(0,np.iinfo(np.uint32).max)\n            return (seed, initial_noise)\n        else:\n            return (seed, None)\n\n    # abstract: subclasses must override this to supply the initial noise tensor\n    def get_noise(self,width,height):\n        \"\"\"\n        Returns a tensor filled with random numbers, either from a normal distribution\n        (txt2img) or from the latent image (img2img, inpaint)\n        \"\"\"\n        raise NotImplementedError(\"get_noise() must be implemented in a descendent class\")\n    \n    def new_seed(self):\n        self.seed = random.randrange(0, np.iinfo(np.uint32).max)\n        return self.seed\n\n    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):\n        '''\n        Spherical linear interpolation\n        Args:\n            t (float/np.ndarray): Float value between 0.0 and 1.0\n            v0 (np.ndarray or torch.Tensor): Starting vector\n            v1 (np.ndarray or torch.Tensor): Final vector\n            DOT_THRESHOLD (float): Threshold for considering the two vectors as\n                                collinear. 
Not recommended to alter this.\n        Returns:\n            v2 (np.ndarray): Interpolation vector between v0 and v1\n        '''\n        inputs_are_torch = False\n        if not isinstance(v0, np.ndarray):\n            inputs_are_torch = True\n            v0 = v0.detach().cpu().numpy()\n        if not isinstance(v1, np.ndarray):\n            inputs_are_torch = True\n            v1 = v1.detach().cpu().numpy()\n\n        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))\n        if np.abs(dot) > DOT_THRESHOLD:\n            v2 = (1 - t) * v0 + t * v1\n        else:\n            theta_0 = np.arccos(dot)\n            sin_theta_0 = np.sin(theta_0)\n            theta_t = theta_0 * t\n            sin_theta_t = np.sin(theta_t)\n            s0 = np.sin(theta_0 - theta_t) / sin_theta_0\n            s1 = sin_theta_t / sin_theta_0\n            v2 = s0 * v0 + s1 * v1\n\n        if inputs_are_torch:\n            v2 = torch.from_numpy(v2).to(self.model.device)\n\n        return v2\n\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/generator/img2img.py",
    "content": "'''\nldm.dream.generator.img2img descends from src.stablediffusion.ldm.dream.generator\n'''\n\nimport torch\nimport numpy as  np\nfrom src.stablediffusion.ldm.dream.devices             import choose_autocast_device\nfrom src.stablediffusion.ldm.dream.generator.base      import Generator\nfrom src.stablediffusion.ldm.models.diffusion.ddim     import DDIMSampler\n\nclass Img2Img(Generator):\n    \"\"\"Generator that derives new images from an initial image plus a text prompt.\"\"\"\n    def __init__(self,model):\n        super().__init__(model)\n        self.init_latent         = None    # set by get_make_image(), read by get_noise()\n\n    @torch.no_grad()\n    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,\n                       conditioning,init_image,strength,step_callback=None,**kwargs):\n        \"\"\"\n        Returns a function returning an image derived from the prompt and the initial image\n        Return value depends on the seed at the time you call it.\n\n        Only the DDIM sampler is supported; any other sampler is replaced by a\n        fresh DDIMSampler. As a side effect, init_image is encoded into latent\n        space and stored in self.init_latent, which get_noise() relies on.\n        t_enc = int(strength * steps) is the step count passed to both\n        sampler.stochastic_encode() and sampler.decode(). Extra keyword\n        arguments are ignored.\n        \"\"\"\n\n        # PLMS sampler not supported yet, so ignore previous sampler\n        if not isinstance(sampler,DDIMSampler):\n            print(\n                f\">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler\"\n            )\n            sampler = DDIMSampler(self.model, device=self.model.device)\n\n        sampler.make_schedule(\n            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False\n        )\n\n        device_type,scope   = choose_autocast_device(self.model.device)\n        with scope(device_type):\n            self.init_latent = self.model.get_first_stage_encoding(\n                self.model.encode_first_stage(init_image)\n            ) # move to latent space\n\n        t_enc = int(strength * steps)\n        uc, c   = conditioning\n\n        @torch.no_grad()\n        def make_image(x_T):\n            # encode (scaled latent)\n            z_enc = sampler.stochastic_encode(\n                self.init_latent,\n                torch.tensor([t_enc]).to(self.model.device),\n                noise=x_T\n            )\n            # decode it\n            samples = sampler.decode(\n                z_enc,\n                c,\n                t_enc,\n                img_callback = step_callback,\n                unconditional_guidance_scale=cfg_scale,\n                unconditional_conditioning=uc,\n            )\n            return self.sample_to_image(samples)\n\n        return make_image\n\n    def get_noise(self,width,height):\n        \"\"\"\n        Returns standard-normal noise shaped like self.init_latent.\n        width and height are unused: the shape comes from the encoded\n        initial image, so get_make_image() must have been called first.\n        \"\"\"\n        device      = self.model.device\n        init_latent = self.init_latent\n        assert init_latent is not None,'call to get_noise() when init_latent not set'\n        if device.type == 'mps':\n            # draw the noise on the CPU, then move it to the MPS device\n            return torch.randn_like(init_latent, device='cpu').to(device)\n        else:\n            return torch.randn_like(init_latent, device=device)\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/generator/inpaint.py",
    "content": "'''\nldm.dream.generator.inpaint descends from src.stablediffusion.ldm.dream.generator\n'''\n\nimport torch\nimport numpy as  np\nfrom einops import rearrange, repeat\nfrom src.stablediffusion.ldm.dream.devices             import choose_autocast_device\nfrom src.stablediffusion.ldm.dream.generator.img2img   import Img2Img\nfrom src.stablediffusion.ldm.models.diffusion.ddim     import DDIMSampler\n\nclass Inpaint(Img2Img):\n    \"\"\"Img2Img variant that also hands a mask and the initial latent to the sampler's decode().\"\"\"\n    def __init__(self,model):\n        self.init_latent = None\n        super().__init__(model)\n\n    @torch.no_grad()\n    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,\n                       conditioning,init_image,mask_image,strength,\n                       step_callback=None,**kwargs):\n        \"\"\"\n        Returns a function returning an image derived from the prompt and\n        the initial image + mask.  Return value depends on the seed at\n        the time you call it.\n\n        Only the DDIM sampler is supported; any other sampler is replaced by\n        a fresh DDIMSampler.  Either way the sampler's schedule is rebuilt\n        for `steps` and `ddim_eta`.  As a side effect init_image is encoded\n        into self.init_latent.  Extra keyword arguments are ignored.\n        \"\"\"\n\n        # Keep the first channel of the first mask in the batch and tile it to\n        # 4 channels with a leading batch dimension of 1 (presumably matching\n        # the latent channel count -- TODO confirm).\n        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)\n        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)\n\n        # PLMS sampler not supported yet, so ignore previous sampler\n        if not isinstance(sampler,DDIMSampler):\n            print(\n                f\">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler\"\n            )\n            sampler = DDIMSampler(self.model, device=self.model.device)\n\n        # Always (re)build the schedule, as Img2Img does: a DDIMSampler passed\n        # in by the caller may not have been scheduled for this step count.\n        sampler.make_schedule(\n            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False\n        )\n\n        device_type,scope   = choose_autocast_device(self.model.device)\n        with scope(device_type):\n            self.init_latent = self.model.get_first_stage_encoding(\n                self.model.encode_first_stage(init_image)\n            ) # move to latent space\n\n        t_enc   = int(strength * steps)\n        uc, c   = conditioning\n\n        print(f\">> target t_enc is {t_enc} steps\")\n\n        @torch.no_grad()\n        def make_image(x_T):\n            # encode (scaled latent)\n            z_enc = sampler.stochastic_encode(\n                self.init_latent,\n                torch.tensor([t_enc]).to(self.model.device),\n                noise=x_T\n            )\n\n            # decode it, passing the mask and the initial latent through\n            samples = sampler.decode(\n                z_enc,\n                c,\n                t_enc,\n                img_callback                 = step_callback,\n                unconditional_guidance_scale = cfg_scale,\n                unconditional_conditioning   = uc,\n                mask                         = mask_image,\n                init_latent                  = self.init_latent\n            )\n            return self.sample_to_image(samples)\n\n        return make_image\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/generator/txt2img.py",
    "content": "'''\nldm.dream.generator.txt2img inherits from src.stablediffusion.ldm.dream.generator\n'''\n\nimport torch\nimport numpy as  np\nfrom src.stablediffusion.ldm.dream.generator.base import Generator\n\nclass Txt2Img(Generator):\n    \"\"\"Generator that creates images from a text prompt, with no initial image.\"\"\"\n    def __init__(self,model):\n        super().__init__(model)\n\n    @torch.no_grad()\n    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,\n                       conditioning,width,height,step_callback=None,**kwargs):\n        \"\"\"\n        Returns a function returning an image derived from the prompt.\n        Return value depends on the seed at the time you call it.\n        width and height are the output size in pixels; the latent shape is\n        derived from them via self.downsampling_factor.  Extra keyword\n        arguments are ignored.\n        \"\"\"\n        uc, c   = conditioning\n\n        @torch.no_grad()\n        def make_image(x_T):\n            samples, _ = sampler.sample(\n                batch_size                   = 1,\n                S                            = steps,\n                x_T                          = x_T,\n                conditioning                 = c,\n                shape                        = self._latent_shape(width, height),\n                verbose                      = False,\n                unconditional_guidance_scale = cfg_scale,\n                unconditional_conditioning   = uc,\n                eta                          = ddim_eta,\n                img_callback                 = step_callback\n            )\n            return self.sample_to_image(samples)\n\n        return make_image\n\n    def _latent_shape(self,width,height):\n        \"\"\"Returns [latent_channels, latent_height, latent_width] for a width x height image.\"\"\"\n        return [\n            self.latent_channels,\n            height // self.downsampling_factor,\n            width  // self.downsampling_factor,\n        ]\n\n    # returns a tensor filled with random numbers from a normal distribution\n    def get_noise(self,width,height):\n        \"\"\"\n        Returns standard-normal noise of shape [1, latent_channels,\n        height // downsampling_factor, width // downsampling_factor].\n        On MPS the noise is drawn on the CPU and then moved to the device.\n        \"\"\"\n        device = self.model.device\n        shape  = [1] + self._latent_shape(width, height)\n        if device.type == 'mps':\n            return torch.randn(shape, device='cpu').to(device)\n        else:\n            return torch.randn(shape, device=device)\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/image_util.py",
    "content": "from math import sqrt, floor, ceil\nfrom PIL import Image\n\nclass InitImageResizer():\n    \"\"\"Simple class to create resized copies of an Image while preserving the aspect ratio.\"\"\"\n    def __init__(self,Image):\n        # NB: this parameter shadows the PIL Image module inside __init__ only;\n        # resize() still uses the module-level import.\n        self.image = Image\n\n    def resize(self,width=None,height=None) -> Image:\n        \"\"\"\n        Return a copy of the image resized to fit within\n        a box width x height. The aspect ratio is\n        maintained. If neither width nor height are provided,\n        then returns a copy of the original image. If one or the other is\n        provided, then the other will be calculated from the\n        aspect ratio.\n\n        Everything is floored to the nearest multiple of 64 so\n        that it can be passed to img2img()\n        \"\"\"\n        im    = self.image\n\n        ar = im.width/float(im.height)\n\n        # Infer missing values from aspect ratio\n        if not(width or height): # both missing\n            width  = im.width\n            height = im.height\n        elif not height:           # height missing\n            height = int(width/ar)\n        elif not width:            # width missing\n            width  = int(height*ar)\n\n        # rw and rh are the resizing width and height for the image.\n        # They maintain the aspect ratio, but may not completely fill up\n        # the requested destination size. Fit the image's longer side to the\n        # box first; if the other side then overflows the box, fit that side\n        # instead, so the result always lies inside width x height.\n        if im.width >= im.height:\n            (rw,rh) = (width,int(width/ar))\n            if rh > height:\n                (rw,rh) = (int(height*ar),height)\n        else:\n            (rw,rh) = (int(height*ar),height)\n            if rw > width:\n                (rw,rh) = (width,int(width/ar))\n\n        #round everything to multiples of 64\n        width,height,rw,rh = map(\n            lambda x: x-x%64, (width,height,rw,rh)\n        )\n\n        # no resize necessary, but return a copy\n        if im.width == width and im.height == height:\n            return im.copy()\n\n        # otherwise resize the original image so that it fits inside the bounding box\n        resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS)\n        return resized_image\n\ndef make_grid(image_list, rows=None, cols=None):\n    \"\"\"\n    Paste the images of image_list, left to right and top to bottom, into one\n    RGB grid image whose cells all have the size of the first image.\n    With neither rows nor cols given the grid is made as square as possible;\n    with only one of them given, the other is computed so every image fits.\n    Images beyond rows * cols are dropped. Raises ValueError on an empty list.\n    \"\"\"\n    image_cnt = len(image_list)\n    if image_cnt == 0:\n        raise ValueError('make_grid() needs at least one image')\n    if rows is None and cols is None:\n        rows = floor(sqrt(image_cnt))  # try to make it square\n        cols = ceil(image_cnt / rows)\n    elif rows is None:\n        rows = ceil(image_cnt / cols)\n    elif cols is None:\n        cols = ceil(image_cnt / rows)\n    width = image_list[0].width\n    height = image_list[0].height\n\n    grid_img = Image.new('RGB', (width * cols, height * rows))\n    for i, image in enumerate(image_list[:rows * cols]):\n        r, c = divmod(i, cols)  # fill row by row\n        grid_img.paste(image, (c * width, r * height))\n\n    return grid_img\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/pngwriter.py",
    "content": "\"\"\"\nTwo helper classes for dealing with PNG images and their path names.\nPngWriter -- Converts Images generated by T2I into PNGs, finds\n             appropriate names for them, and writes prompt metadata\n             into the PNG.\nPromptFormatter -- Utility for converting a Namespace of prompt parameters\n             back into a formatted prompt string with command-line switches.\n\"\"\"\nimport os\nimport re\nfrom PIL import PngImagePlugin\n\n# -------------------image generation utils-----\n\n\nclass PngWriter:\n    \"\"\"Writes PNGs with embedded prompt metadata into outdir (created if missing).\"\"\"\n    def __init__(self, outdir):\n        self.outdir = outdir\n        os.makedirs(outdir, exist_ok=True)\n\n    # gives the next unique prefix in outdir\n    def unique_prefix(self):\n        \"\"\"\n        Return the next unused numeric prefix for outdir, zero-padded to at\n        least six digits ('000001' when no '<digits>.<anything>.png' file\n        exists yet). Prefixes are compared as integers: a reverse lexical\n        sort would rank '999999' above '1000000' and hand out duplicates.\n        \"\"\"\n        # raw string: '\\d' and '\\.' are invalid escapes in a normal literal\n        pattern = re.compile(r'^(\\d+)\\..*\\.png')\n        largest = 0\n        for filename in os.listdir(self.outdir):\n            match = pattern.match(filename)\n            if match:\n                largest = max(largest, int(match.group(1)))\n        return f'{largest + 1:06}'\n\n    # saves image named _image_ to outdir/name, writing metadata from prompt\n    # returns full path of output\n    def save_image_and_prompt_to_png(self, image, prompt, name):\n        path = os.path.join(self.outdir, name)\n        info = PngImagePlugin.PngInfo()\n        info.add_text('Dream', prompt)\n        image.save(path, 'PNG', pnginfo=info)\n        return path\n\n\nclass PromptFormatter:\n    \"\"\"Rebuilds a dream-style command line from prompt options, falling back to t2i defaults.\"\"\"\n    def __init__(self, t2i, opt):\n        self.t2i = t2i\n        self.opt = opt\n\n    # note: the t2i object should provide all these values.\n    # there should be no need to \"or\" them against opt values\n    def normalize_prompt(self):\n        \"\"\"Normalize the prompt and switches\"\"\"\n        t2i = self.t2i\n        opt = self.opt\n\n        switches = list()\n        switches.append(f'\"{opt.prompt}\"')\n        switches.append(f'-s{opt.steps        or t2i.steps}')\n        switches.append(f'-W{opt.width        or t2i.width}')\n        switches.append(f'-H{opt.height       or t2i.height}')\n        switches.append(f'-C{opt.cfg_scale    or t2i.cfg_scale}')\n        switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')\n# to do: put model name into the t2i object\n#        switches.append(f'--model{t2i.model_name}')\n        if opt.seamless or t2i.seamless:\n            switches.append('--seamless')\n        if opt.init_img:\n            switches.append(f'-I{opt.init_img}')\n        if opt.fit:\n            switches.append('--fit')\n        if opt.strength and opt.init_img is not None:\n            switches.append(f'-f{opt.strength or t2i.strength}')\n        if opt.gfpgan_strength:\n            switches.append(f'-G{opt.gfpgan_strength}')\n        if opt.upscale:\n            switches.append(f'-U {\" \".join([str(u) for u in opt.upscale])}')\n        if opt.variation_amount > 0:\n            switches.append(f'-v{opt.variation_amount}')\n        if opt.with_variations:\n            formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)\n            switches.append(f'-V{formatted_variations}')\n        return ' '.join(switches)\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/readline.py",
    "content": "\"\"\"\nReadline helper functions for dream.py (linux and mac only).\n\"\"\"\nimport os\nimport re\nimport atexit\n\n# ---------------readline utilities---------------------\ntry:\n    import readline\n\n    readline_available = True\nexcept ImportError:\n    readline_available = False\n\n\nclass Completer:\n    \"\"\"\n    readline tab-completion callback. Completes command-line switches from a\n    fixed option list, and file-system paths after -I/--init_img= and\n    -M/--init_mask=, after a trailing 'cd', or for text starting with '.' or '/'.\n    \"\"\"\n    def __init__(self, options):\n        # de-duplicate so the same switch is never offered twice\n        self.options = sorted(set(options))\n\n    def complete(self, text, state):\n        \"\"\"readline completer protocol: return the state'th completion of text, or None.\"\"\"\n        buffer = readline.get_line_buffer()\n\n        if text.startswith(('-I', '--init_img','-M','--init_mask')):\n            return self._path_completions(text, state, ('.png','.jpg','.jpeg'))\n\n        if buffer.strip().endswith('cd') or text.startswith(('.', '/')):\n            return self._path_completions(text, state, ())\n\n        if state == 0:\n            # This is the first time for this text, so build a match list.\n            if text:\n                self.matches = [\n                    s for s in self.options if s and s.startswith(text)\n                ]\n            else:\n                self.matches = self.options[:]\n\n        # Return the state'th item from the match list,\n        # if we have that many.\n        try:\n            response = self.matches[state]\n        except IndexError:\n            response = None\n        return response\n\n    def _path_completions(self, text, state, extensions):\n        \"\"\"\n        Return the state'th path completion for text, or None. A leading\n        -I / --init_img= / --init_mask= / -M switch is stripped before the\n        lookup and put back in front of every completion. Hidden entries are\n        skipped; matching directories get a trailing '/', and matching files\n        are offered only if they end with one of `extensions`.\n        \"\"\"\n        # split off any switch prefix so that only the path part is completed\n        prefix = ''\n        for switch in ('-I', '--init_img=', '--init_mask=', '-M'):\n            if text.startswith(switch):\n                prefix = switch\n                break\n        path = text[len(prefix):].lstrip()\n        typed_dir = os.path.dirname(path)  # directory part as typed, before ~ expansion\n\n        matches = list()\n\n        path = os.path.expanduser(path)\n        if len(path) == 0:\n            matches.append(text + './')\n        else:\n            dir = os.path.dirname(path)\n            try:\n                # a bare name such as 'foo' has an empty dirname: list the cwd\n                dir_list = os.listdir(dir or '.')\n            except OSError:\n                # nonexistent or unreadable directory: nothing to complete\n                dir_list = []\n            for n in dir_list:\n                if n.startswith('.') and len(n) > 1:\n                    continue\n                full_path = os.path.join(dir, n)\n                if full_path.startswith(path):\n                    completion = prefix + os.path.join(typed_dir, n)\n                    if os.path.isdir(full_path):\n                        matches.append(completion + '/')\n                    elif n.endswith(extensions):\n                        matches.append(completion)\n\n        try:\n            response = matches[state]\n        except IndexError:\n            response = None\n        return response\n\n\nif readline_available:\n    readline.set_completer(\n        Completer(\n            [\n                '--steps','-s',\n                '--seed','-S',\n                '--iterations','-n',\n                '--width','-W','--height','-H',\n                '--cfg_scale','-C',\n                '--grid','-g',\n                '--individual','-i',\n                '--init_img','-I',\n                '--init_mask','-M',\n                '--strength','-f',\n                '--variants','-v',\n                '--outdir','-o',\n                '--sampler','-A','-m',\n                '--embedding_path',\n                '--device',\n                '--gfpgan_strength','-G',\n                '--upscale','-U',\n                '-save_orig','--save_original',\n                '--skip_normalize','-x',\n                '--log_tokenization','-t',\n            ]\n        ).complete\n    )\n    readline.set_completer_delims(' ')\n    readline.parse_and_bind('tab: complete')\n\n    histfile = os.path.join(os.path.expanduser('~'), '.dream_history')\n    try:\n        readline.read_history_file(histfile)\n        readline.set_history_length(1000)\n    except FileNotFoundError:\n        pass\n    atexit.register(readline.write_history_file, histfile)\n"
  },
  {
    "path": "src/stablediffusion/ldm/dream/server.py",
    "content": "import argparse\nimport json\nimport base64\nimport mimetypes\nimport os\nimport tempfile\nfrom http.server import BaseHTTPRequestHandler, ThreadingHTTPServer\nfrom src.stablediffusion.ldm.dream.pngwriter import PngWriter, PromptFormatter\nfrom threading import Event\n\ndef build_opt(post_data, seed, gfpgan_model_exists):\n    \"\"\"\n    Convert the decoded JSON body of a POST request into an argparse.Namespace\n    of generation options. Raises CanceledException when 'with_variations'\n    is not a comma-separated list of seed:weight pairs.\n    \"\"\"\n    opt = argparse.Namespace()\n    setattr(opt, 'prompt', post_data['prompt'])\n    setattr(opt, 'init_img', post_data['initimg'])\n    setattr(opt, 'strength', float(post_data['strength']))\n    setattr(opt, 'iterations', int(post_data['iterations']))\n    setattr(opt, 'steps', int(post_data['steps']))\n    setattr(opt, 'width', int(post_data['width']))\n    setattr(opt, 'height', int(post_data['height']))\n    setattr(opt, 'seamless', 'seamless' in post_data)\n    setattr(opt, 'fit', 'fit' in post_data)\n    setattr(opt, 'mask', 'mask' in post_data)\n    setattr(opt, 'invert_mask', 'invert_mask' in post_data)\n    setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))\n    setattr(opt, 'sampler_name', post_data['sampler_name'])\n    setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)\n    setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)\n    setattr(opt, 'progress_images', 'progress_images' in post_data)\n    setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))\n    setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)\n    setattr(opt, 'with_variations', [])\n\n    broken = False\n    if int(post_data['seed']) != -1 and post_data['with_variations'] != '':\n        for part in post_data['with_variations'].split(','):\n            seed_and_weight = part.split(':')\n            if len(seed_and_weight) != 2:\n                print(f'could not parse with_variation part \"{part}\"')\n                broken = True\n                break\n            try:\n                seed = int(seed_and_weight[0])\n                weight = float(seed_and_weight[1])\n            except ValueError:\n                print(f'could not parse with_variation part \"{part}\"')\n                broken = True\n                break\n            opt.with_variations.append([seed, weight])\n\n    if broken:\n        raise CanceledException\n\n    if len(opt.with_variations) == 0:\n        opt.with_variations = None\n\n    return opt\n\nclass CanceledException(Exception):\n    pass\n\nclass DreamServer(BaseHTTPRequestHandler):\n    model = None\n    outdir = None\n    canceled = Event()\n\n    def do_GET(self):\n        if self.path == \"/\":\n            self.send_response(200)\n            self.send_header(\"Content-type\", \"text/html\")\n            self.end_headers()\n            with open(\"./static/dream_web/index.html\", \"rb\") as content:\n                self.wfile.write(content.read())\n        elif self.path == \"/config.js\":\n            # unfortunately this import can't be at the top level, since that would cause a circular import\n            from src.stablediffusion.ldm.gfpgan.gfpgan_tools import gfpgan_model_exists\n            self.send_response(200)\n            self.send_header(\"Content-type\", \"application/javascript\")\n            self.end_headers()\n            config = {\n                'gfpgan_model_exists': gfpgan_model_exists\n            }\n            self.wfile.write(bytes(\"let config = \" + json.dumps(config) + \";\\n\", \"utf-8\"))\n        elif self.path == \"/run_log.json\":\n            self.send_response(200)\n            self.send_header(\"Content-type\", \"application/json\")\n            self.end_headers()\n            output = []\n\n            log_file = os.path.join(self.outdir, \"dream_web_log.txt\")\n            if os.path.exists(log_file):\n                with open(log_file, \"r\") as log:\n                    for line in log:\n                        url, config = line.split(\": {\", maxsplit=1)\n                        config = json.loads(\"{\" + config)\n                        config[\"url\"] = url.lstrip(\".\")\n                        if os.path.exists(url):\n                            output.append(config)\n\n            self.wfile.write(bytes(json.dumps({\"run_log\": output}), \"utf-8\"))\n        elif self.path == \"/cancel\":\n            self.canceled.set()\n            self.send_response(200)\n            self.send_header(\"Content-type\", \"application/json\")\n            self.end_headers()\n            self.wfile.write(bytes('{}', 'utf8'))\n        else:\n            path = \".\" + self.path\n            cwd = os.path.realpath(os.getcwd())\n            try:\n                # commonpath compares whole path components; commonprefix would\n                # compare characters and accept e.g. \"<cwd>-other/secret\"\n                is_in_cwd = os.path.commonpath((os.path.realpath(path), cwd)) == cwd\n            except ValueError: # e.g. paths on different drives (Windows)\n                is_in_cwd = False\n            mime_type = mimetypes.guess_type(path)[0]\n            if not (is_in_cwd and os.path.isfile(path)) or mime_type is None:\n                # send_error() also ends the headers, so the client gets a real 404\n                self.send_error(404)\n                return\n            self.send_response(200)\n            self.send_header(\"Content-type\", mime_type)\n            self.end_headers()\n            with open(path, \"rb\") as content:\n                self.wfile.write(content.read())\n\n    def do_POST(self):\n        self.send_response(200)\n        self.send_header(\"Content-type\", \"application/json\")\n        self.end_headers()\n\n        # unfortunately this import can't be at the top level, since that would cause a circular import\n        from src.stablediffusion.ldm.gfpgan.gfpgan_tools import gfpgan_model_exists\n\n        content_length = int(self.headers['Content-Length'])\n        post_data = json.loads(self.rfile.read(content_length))\n        try:\n            opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)\n        except CanceledException:\n            # build_opt() has already printed which part could not be parsed\n            print(\">> Could not parse request options; request ignored.\")\n            return\n\n        self.canceled.clear()\n        print(f\">> Request to generate with prompt: {opt.prompt}\")\n        # In order to handle upscaled images, the PngWriter needs to maintain state\n        # across images generated by each call to prompt2img(), so we define it in\n        # the outer scope of image_done()\n        config = post_data.copy() # Shallow copy\n        config['initimg'] = config.pop('initimg_name', '')\n\n        images_generated = 0    # helps keep track of when upscaling is started\n        images_upscaled = 0     # helps keep track of when upscaling is completed\n        pngwriter = PngWriter(self.outdir)\n\n        prefix = pngwriter.unique_prefix()\n        # if upscaling is requested, then this will be called twice, once when\n        # the images are first generated, and then again when after upscaling\n        # is complete. The upscaling replaces the original file, so the second\n        # entry should not be inserted into the image list.\n        def image_done(image, seed, upscaled=False):\n            name = f'{prefix}.{seed}.png'\n            iter_opt = argparse.Namespace(**vars(opt)) # copy\n            if opt.variation_amount > 0:\n                this_variation = [[seed, opt.variation_amount]]\n                if opt.with_variations is None:\n                    iter_opt.with_variations = this_variation\n                else:\n                    iter_opt.with_variations = opt.with_variations + this_variation\n                iter_opt.variation_amount = 0\n            elif opt.with_variations is None:\n                iter_opt.seed = seed\n            normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()\n            path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)\n\n            if int(config['seed']) == -1:\n                config['seed'] = seed\n            # Append post_data to log, but only once!\n            if not upscaled:\n                with open(os.path.join(self.outdir, \"dream_web_log.txt\"), \"a\") as log:\n                    log.write(f\"{path}: {json.dumps(config)}\\n\")\n\n                self.wfile.write(bytes(json.dumps(\n                    {'event': 'result', 'url': path, 'seed': seed, 'config': config}\n                ) + '\\n',\"utf-8\"))\n\n            # control state of the \"postprocessing...\" message\n            upscaling_requested = opt.upscale or opt.gfpgan_strength > 0\n            nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.\n            nonlocal images_upscaled  # NB: Is this bad python style? It is typical usage in a perl closure.\n            if upscaled:\n                images_upscaled += 1\n            else:\n                images_generated += 1\n            if upscaling_requested:\n                action = None\n                if images_generated >= opt.iterations:\n                    if images_upscaled < opt.iterations:\n                        action = 'upscaling-started'\n                    else:\n                        action = 'upscaling-done'\n                if action:\n                    x = images_upscaled + 1\n                    self.wfile.write(bytes(json.dumps(\n                        {'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}\n                    ) + '\\n',\"utf-8\"))\n\n        step_writer = PngWriter(os.path.join(self.outdir, \"intermediates\"))\n        step_index = 1\n        def image_progress(sample, step):\n            if self.canceled.is_set():\n                self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\\n', 'utf-8'))\n                raise CanceledException\n            path = None\n            # since rendering images is moderately expensive, only render every 5th image\n            # and don't bother with the last one, since it'll render anyway\n            nonlocal step_index\n            if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:\n                image = self.model.sample_to_image(sample)\n                name = f'{prefix}.{opt.seed}.{step_index}.png'\n                metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'\n                path = step_writer.save_image_and_prompt_to_png(image, metadata, name)\n                step_index += 1\n            self.wfile.write(bytes(json.dumps(\n                {'event': 'step', 'step': step + 1, 'url': path}\n            ) + '\\n',\"utf-8\"))\n\n        try:\n            if opt.init_img is None:\n                # Run txt2img\n                self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)\n            else:\n                # Decode initimg as base64 into a per-request temp file, so that\n                # concurrent requests (ThreadingHTTPServer) cannot overwrite each\n                # other's init image; the file is removed even if decoding fails\n                fd, tmp_path = tempfile.mkstemp(prefix=\"img2img-\", suffix=\".png\")\n                try:\n                    with os.fdopen(fd, \"wb\") as f:\n                        initimg = opt.init_img.split(\",\")[1] # Ignore mime type\n                        f.write(base64.b64decode(initimg))\n                    opt1 = argparse.Namespace(**vars(opt))\n                    opt1.init_img = tmp_path\n\n                    # Run img2img\n                    self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)\n                finally:\n                    # Remove the temp file\n                    os.remove(tmp_path)\n        except CanceledException:\n            print(\"Canceled.\")\n            return\n\n\nclass ThreadingDreamServer(ThreadingHTTPServer):\n    def __init__(self, server_address):\n        super(ThreadingDreamServer, self).__init__(server_address, DreamServer)\n"
  },
  {
    "path": "src/stablediffusion/ldm/generate.py",
    "content": "# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)\n\n# Derived from source code carrying the following copyrights\n# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich\n# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors\n\n\"\"\"Simplified text to image API for stable diffusion/latent diffusion\n\nExample Usage:\n\nfrom src.stablediffusion.ldm.generate import Generate\n\n# Create an object with default values\ngr = Generate()\n\n# do the slow model initialization\ngr.load_model()\n\n# Do the fast inference & image generation. Any options passed here\n# override the default values assigned during class initialization\n# Will call load_model() if the model was not previously loaded and so\n# may be slow at first.\n# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]\nresults = gr.prompt2png(prompt     = \"an astronaut riding a horse\",\n                         outdir     = \"./outputs/samples\",\n                         iterations = 3)\n\nfor row in results:\n    print(f'filename={row[0]}')\n    print(f'seed    ={row[1]}')\n\n# Same thing, but using an initial image.\nresults = gr.prompt2png(prompt   = \"an astronaut riding a horse\",\n                         outdir   = \"./outputs/\",\n                         iterations = 3,\n                         init_img = \"./sketches/horse+rider.png\")\n\nfor row in results:\n    print(f'filename={row[0]}')\n    print(f'seed    ={row[1]}')\n\n# Same thing, but we return a series of Image objects, which lets you manipulate them,\n# combine them, and save them under arbitrary names\n\nresults = gr.prompt2image(prompt   = \"an astronaut riding a horse\")\nfor row in results:\n    im   = row[0]\n    seed = row[1]\n    im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')\n    im.thumbnail((100,100))\n    im.save('./outputs/samples/astronaut_thumb.jpg')\n\nNote that the old txt2img() and img2img() calls are deprecated but will\nstill work.\n\nThe full list of arguments to Generate() are:\ngr = Generate(\n          weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')\n          config      = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')\n          iterations  = <integer>     // how many times to run the sampling (1)\n          steps       = <integer>     // 50\n          sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms\n          grid        = <boolean>     // false\n          width       = <integer>     // image width, multiple of 64 (512)\n          height      = <integer>     // image height, multiple of 64 (512)\n          cfg_scale   = <float>       // classifier-free guidance scale (7.5)\n          ddim_eta    = <float>       // 0.0 (deterministic)\n          precision   = <string>      // 'autocast'\n          full_precision = <boolean>  // false (always true on mps)\n          strength    = <float>       // img2img/inpaint strength (0.75)\n          seamless    = <boolean>     // false\n          embedding_path = <path>     // embedding file loaded into model.embedding_manager (None)\n          device_type = <string>      // 'cuda'; another device is chosen if cuda is unavailable\n          ignore_ctrl_c = <boolean>   // false\n          )\n\nThe seed is not a Generate() argument; pass seed= to prompt2png() or prompt2image() instead.\n\n\"\"\"\n\nimport torch\nimport numpy as np\nimport random\nimport os\nimport time\nimport re\nimport sys\nimport traceback\nimport transformers\n\nfrom omegaconf import OmegaConf\nfrom PIL import Image, ImageOps\nfrom torch import nn\nfrom pytorch_lightning import seed_everything\n\nfrom src.stablediffusion.ldm.util                      import instantiate_from_config\nfrom src.stablediffusion.ldm.models.diffusion.ddim     import DDIMSampler\nfrom src.stablediffusion.ldm.models.diffusion.plms     import PLMSSampler\nfrom src.stablediffusion.ldm.models.diffusion.ksampler import KSampler\nfrom src.stablediffusion.ldm.dream.pngwriter           import PngWriter\nfrom src.stablediffusion.ldm.dream.image_util          import InitImageResizer\nfrom src.stablediffusion.ldm.dream.devices             import choose_torch_device\nfrom src.stablediffusion.ldm.dream.conditioning        import get_uc_and_c\n\n\nclass Generate:\n    \"\"\"Generate class\n    Stores default values for multiple configuration items\n    \"\"\"\n\n    def __init__(\n            self,\n            iterations            = 1,\n            steps                 = 50,\n            cfg_scale             = 7.5,\n            weights               = 'models/ldm/stable-diffusion-v1/model.ckpt',\n            config                = 'configs/stable-diffusion/v1-inference.yaml',\n            grid                  = False,\n            width                 = 512,\n            height                = 512,\n            sampler_name          = 'k_lms',\n            ddim_eta              = 0.0,  # deterministic\n            precision             = 'autocast',\n            full_precision        = False,\n            strength              = 0.75,  # default in scripts/img2img.py\n            seamless              = False,\n            embedding_path        = None,\n            device_type           = 'cuda',\n            ignore_ctrl_c         = False,\n    ):\n        self.iterations               = iterations\n        self.width                    = width\n        self.height                   = height\n        self.steps                    = steps\n        self.cfg_scale                = cfg_scale\n        self.weights                  = weights\n        self.config                   = config\n        self.sampler_name             = sampler_name\n        self.grid                     = grid\n        self.ddim_eta                 = ddim_eta\n        self.precision                = precision\n        self.full_precision           = True if choose_torch_device() == 'mps' else full_precision\n        self.strength                 = strength\n        self.seamless                 = seamless\n        self.embedding_path           = embedding_path\n        self.device_type              = device_type\n        self.ignore_ctrl_c            = ignore_ctrl_c    # note, this logic probably doesn't belong here...\n        self.model                    = None     # empty for now\n        self.sampler                  = None\n        self.device                   = None\n        self.generators               = {}\n        self.base_generator           = None\n        self.seed                     = None\n\n        if device_type == 'cuda' and not torch.cuda.is_available():\n            device_type = choose_torch_device()\n            print(\">> cuda not available, using device\", device_type)\n        self.device = torch.device(device_type)\n\n        # for VRAM usage statistics\n        device_type          = choose_torch_device()\n        self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None\n        transformers.logging.set_verbosity_error()\n\n    def prompt2png(self, prompt, outdir, **kwargs):\n        \"\"\"\n        Takes a prompt and an output directory, writes out the requested number\n        of PNG files, and returns an array of [[filename,seed],[filename,seed]...]\n        Optional named arguments are the same as those passed to Generate and prompt2image()\n        \"\"\"\n        results = self.prompt2image(prompt, **kwargs)\n        pngwriter = PngWriter(outdir)\n        prefix = pngwriter.unique_prefix()\n        outputs = []\n        for image, seed in results:\n            name = f'{prefix}.{seed}.png'\n            path = pngwriter.save_image_and_prompt_to_png(\n                image, f'{prompt} -S{seed}', name)\n            outputs.append([path, seed])\n        return outputs\n\n    def txt2img(self, prompt, **kwargs):\n        outdir = kwargs.pop('outdir', 'outputs/img-samples')\n        return self.prompt2png(prompt, outdir, **kwargs)\n\n    def img2img(self, prompt, **kwargs):\n        outdir = kwargs.pop('outdir', 'outputs/img-samples')\n        assert (\n            'init_img' in kwargs\n        ), 'call to img2img() must include the init_img argument'\n        return self.prompt2png(prompt, outdir, **kwargs)\n\n    def prompt2image(\n            self,\n            # these are common\n            prompt,\n            iterations     =    None,\n            steps          =    None,\n            seed           =    None,\n            cfg_scale      =    None,\n            ddim_eta       =    None,\n            skip_normalize =    False,\n            image_callback =    None,\n            step_callback  =    None,\n            width          =    None,\n            height         =    None,\n            sampler_name   =    None,\n            seamless       =    False,\n            log_tokenization=  False,\n            with_variations =   None,\n            variation_amount =  0.0,\n            # these are specific to img2img and inpaint\n            init_img       =    None,\n            init_mask      =    None,\n            fit            =    False,\n            strength       =    None,\n            # these are specific to GFPGAN/ESRGAN\n            gfpgan_strength=    0,\n            save_original  =    False,\n            upscale        =    None,\n            **args,\n    ):   # eat up additional cruft\n        \"\"\"\n        ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()\n        It takes the following arguments:\n           prompt                          // prompt string (no default)\n           iterations                      // iterations (1); image count=iterations\n           steps                           // refinement steps per iteration\n           seed                            // seed for random number generator\n           width                           // width of image, in multiples of 64 (512)\n           height                          // height of image, in multiples of 64 (512)\n           cfg_scale                       // how strongly the prompt influences the image (7.5) (must be >1)\n           seamless                        // whether the generated image should tile\n           init_img                        // path to an initial image\n           init_mask                       // path to a mask image for inpainting; its transparent region is inpainted\n           fit                             // if true, resize init_img/init_mask to fit inside width x height\n           strength                        // strength for noising/unnoising init_img (must satisfy 0.0 < strength < 1.0). Values near 0.0 preserve the image, values near 1.0 replace it\n           gfpgan_strength                 // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely\n           upscale                         // None, or [scale] / [scale, strength] for Real-ESRGAN upscaling (strength defaults to 0.75)\n           ddim_eta                        // image randomness (eta=0.0 means the same seed always produces the same image)\n           step_callback                   // a function or method that will be called each step\n           image_callback                  // a function or method that will be called each time an image is generated\n           with_variations                 // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation\n           variation_amount                // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)\n\n        To use the step callback, define a function that receives two arguments:\n        - Image GPU data\n        - The step number\n\n        To use the image callback, define a function or method that receives two arguments, an Image object\n        and the seed. You can then do whatever you like with the image, including converting it to\n        different formats and manipulating it. When upscale or gfpgan_strength is used, the callback is\n        called again for each processed image with the extra keyword argument upscaled=True, so it\n        should accept that keyword. For example:\n\n            def process_image(image, seed, upscaled=False):\n                image.save(f'images/{seed}.png')\n\n        prompt2png() does not use a callback: it saves the returned images with PngWriter\n        (ldm/dream/pngwriter.py), giving each file a unique name prefix and writing the prompt\n        and seed into the PNG.\n        \"\"\"\n        # TODO: convert this into a getattr() loop\n        steps                 = steps      or self.steps\n        width                 = width      or self.width\n        height                = height     or self.height\n        seamless              = seamless   or self.seamless\n        cfg_scale             = cfg_scale  or self.cfg_scale\n        ddim_eta              = ddim_eta   or self.ddim_eta\n        iterations            = iterations or self.iterations\n        strength              = strength   or self.strength\n        self.seed             = seed\n        self.log_tokenization = log_tokenization\n        with_variations = [] if with_variations is None else with_variations\n\n        model = (\n            self.load_model()\n        )  # will instantiate the model or return it from cache\n\n        for m in model.modules():\n            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):\n                m.padding_mode = 'circular' if seamless else m._orig_padding_mode\n        \n        assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'\n        assert (\n            0.0 < strength < 1.0\n        ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'\n        assert (\n                0.0 <= variation_amount <= 1.0\n        ), '-v --variation_amount must be in [0.0, 1.0]'\n\n        # NOTE(review): variation_amount is asserted to be <= 1.0 just above, so the\n        # `variation_amount > 1.0` clause can never be true and these checks only run\n        # when with_variations is given. `> 0.0` may have been intended; confirm before\n        # changing it, because that would make a seed mandatory whenever -v is used.\n        if len(with_variations) > 0 or variation_amount > 1.0:\n            assert seed is not None,\\\n                'seed must be specified when using with_variations'\n            if variation_amount == 0.0:\n                assert iterations == 1,\\\n                    'when using --with_variations, multiple iterations are only possible when using --variation_amount'\n            assert all(0 <= weight <= 1 for _, weight in with_variations),\\\n                f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'\n\n        width, height, _ = self._resolution_check(width, height, log=True)\n\n        if sampler_name and (sampler_name != self.sampler_name):\n            self.sampler_name = sampler_name\n            self._set_sampler()\n\n        tic = time.time()\n        if torch.cuda.is_available():\n            torch.cuda.reset_peak_memory_stats()\n\n        results          = list()\n        init_image       = None\n        mask_image       = None\n\n        try:\n            uc, c = get_uc_and_c(\n                prompt, model=self.model,\n                skip_normalize=skip_normalize,\n                log_tokens=self.log_tokenization\n            )\n\n            (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)\n            \n            if (init_image is not None) and (mask_image is not None):\n                generator = self._make_inpaint()\n            elif init_image is not None:\n                generator = self._make_img2img()\n            else:\n                generator = self._make_txt2img()\n\n            generator.set_variation(self.seed, variation_amount, with_variations)\n            results = generator.generate(\n                prompt,\n                iterations     = iterations,\n                seed           = self.seed,\n                sampler        = self.sampler,\n                steps          = steps,\n                cfg_scale      = cfg_scale,\n                conditioning   = (uc,c),\n                ddim_eta       = ddim_eta,\n                image_callback = image_callback,  # called after the final image is generated\n                step_callback  = step_callback,   # called after each intermediate image is generated\n                width          = width,\n                height         = height,\n                init_image     = init_image,      # notice that init_image is different from init_img\n                mask_image     = mask_image,\n                strength       = strength,\n            )\n\n            if upscale is not None or gfpgan_strength > 0:\n                self.upscale_and_reconstruct(results,\n                                             upscale        = upscale,\n                                             strength       = gfpgan_strength,\n                                             save_original  = save_original,\n                                             image_callback = image_callback)\n\n        except KeyboardInterrupt:\n            print('*interrupted*')\n            if not self.ignore_ctrl_c:\n                raise KeyboardInterrupt\n            print(\n                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'\n            )\n        except RuntimeError as e:\n            print(traceback.format_exc(), file=sys.stderr)\n            print('>> Could not generate image.')\n\n        toc = time.time()\n        print('>> Usage stats:')\n        print(\n            f'>>   {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)\n        )\n        if torch.cuda.is_available() and self.device.type == 'cuda':\n            print(\n                f'>>   Max VRAM used for this generation:',\n                '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),\n                'Current VRAM utilization:'\n                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),\n            )\n\n            self.session_peakmem = max(\n                self.session_peakmem, torch.cuda.max_memory_allocated()\n            )\n            print(\n                f'>>   Max VRAM used since script start: ',\n                '%4.2fG' % (self.session_peakmem / 1e9),\n            )\n        return results\n\n    def _make_images(self, img_path, mask_path, width, height, fit=False):\n        \"\"\"Load init_img (and optionally init_mask) and return them as (init_image, init_mask) tensors.\n\n        Returns (None, None) when no img_path is given. If the image has transparent areas and no\n        mask_path is given, the mask is derived from the image's alpha channel.\n        \"\"\"\n        init_image      = None\n        init_mask       = None\n        if not img_path:\n            return None,None\n\n        image        = self._load_img(img_path, width, height, fit=fit) # this returns an Image\n        init_image   = self._create_init_image(image)                   # this returns a torch tensor\n\n        if self._has_transparency(image) and not mask_path:      # if image has a transparent area and no mask was provided, then try to generate mask\n            print('>> Initial image has transparent areas. Will inpaint in these regions.')\n            # _check_for_erasure() and _create_init_mask() read the alpha band directly; palette\n            # (\"P\") images or images with a transparency key have none until converted to RGBA\n            rgba_image = image.convert('RGBA')\n            if self._check_for_erasure(rgba_image):\n                print(\n                    '>> WARNING: Colors underneath the transparent region seem to have been erased.\\n',\n                    '>>          Inpainting will be suboptimal. Please preserve the colors when making\\n',\n                    '>>          a transparency mask, or provide mask explicitly using --init_mask (-M).'\n                )\n            init_mask = self._create_init_mask(rgba_image)              # this returns a torch tensor\n\n        if mask_path:\n            mask_image  = self._load_img(mask_path, width, height, fit=fit) # this returns an Image\n            init_mask   = self._create_init_mask(mask_image)\n\n        return init_image,init_mask\n\n    def _make_img2img(self):\n        if not self.generators.get('img2img'):\n            from src.stablediffusion.ldm.dream.generator.img2img import Img2Img\n            self.generators['img2img'] = Img2Img(self.model)\n        return self.generators['img2img']\n\n    def _make_txt2img(self):\n        if not self.generators.get('txt2img'):\n            from src.stablediffusion.ldm.dream.generator.txt2img import Txt2Img\n            self.generators['txt2img'] = Txt2Img(self.model)\n        return self.generators['txt2img']\n\n    def _make_inpaint(self):\n        if not self.generators.get('inpaint'):\n            from src.stablediffusion.ldm.dream.generator.inpaint import Inpaint\n            self.generators['inpaint'] = Inpaint(self.model)\n        return self.generators['inpaint']\n\n    def load_model(self):\n        \"\"\"Load and initialize the model from configuration variables passed at object creation time\"\"\"\n        if self.model is None:\n            seed_everything(random.randrange(0, np.iinfo(np.uint32).max))\n            try:\n                config = OmegaConf.load(self.config)\n                model = self._load_model_from_config(config, self.weights)\n                if self.embedding_path is not None:\n                    model.embedding_manager.load(\n                        self.embedding_path, self.full_precision\n                    )\n                self.model = model.to(self.device)\n                # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here\n                self.model.cond_stage_model.device = self.device\n            except AttributeError as e:\n                print(f'>> Error loading model. {str(e)}', file=sys.stderr)\n                print(traceback.format_exc(), file=sys.stderr)\n                raise SystemExit from e\n\n            self._set_sampler()\n\n            for m in self.model.modules():\n                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):\n                    m._orig_padding_mode = m.padding_mode\n\n        return self.model\n\n    def upscale_and_reconstruct(self,\n                                image_list,\n                                upscale       = None,\n                                strength      =  0.0,\n                                save_original = False,\n                                image_callback = None):\n        \"\"\"Apply Real-ESRGAN upscaling and/or GFPGAN face restoration to each [image, seed] row.\n\n        upscale        -- None, or [scale] / [scale, strength]; the strength defaults to 0.75\n        strength       -- GFPGAN strength; face restoration is skipped when it is 0\n        save_original  -- accepted for API compatibility; not used here\n        image_callback -- if given, called as image_callback(image, seed, upscaled=True) for each\n                          processed image; otherwise row[0] is replaced by the processed image\n                          (so rows must be mutable lists)\n        \"\"\"\n        try:\n            if upscale is not None:\n                from src.stablediffusion.ldm.gfpgan.gfpgan_tools import real_esrgan_upscale\n            if strength > 0:\n                from src.stablediffusion.ldm.gfpgan.gfpgan_tools import run_gfpgan\n        except (ModuleNotFoundError, ImportError):\n            print(traceback.format_exc(), file=sys.stderr)\n            print('>> You may need to install the ESRGAN and/or GFPGAN modules')\n            return\n\n        for r in image_list:\n            image, seed = r\n            try:\n                if upscale is not None:\n                    # upscale is [scale] or [scale, strength]; read the optional strength without\n                    # appending the default to the caller's list\n                    upscale_strength = upscale[1] if len(upscale) > 1 else 0.75\n                    image = real_esrgan_upscale(\n                        image,\n                        upscale_strength,\n                        int(upscale[0]),\n                        seed,\n                    )\n                if strength > 0:\n                    image = run_gfpgan(\n                        image, strength, seed, 1\n                    )\n            except Exception as e:\n                print(\n                    f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\\n{e}'\n                )\n\n            if image_callback is not None:\n                image_callback(image, seed, upscaled=True)\n            else:\n                r[0] = image\n\n    # to help WebGUI - front end to generator util function\n    def sample_to_image(self,samples):\n        return self._sample_to_image(samples)\n\n    def _sample_to_image(self,samples):\n        if not self.base_generator:\n            from src.stablediffusion.ldm.dream.generator import Generator\n            self.base_generator = Generator(self.model)\n        return self.base_generator.sample_to_image(samples)\n\n    def _set_sampler(self):\n        msg = f'>> Setting Sampler to {self.sampler_name}'\n        if self.sampler_name == 'plms':\n            self.sampler = PLMSSampler(self.model, device=self.device)\n        elif self.sampler_name == 'ddim':\n            self.sampler = DDIMSampler(self.model, device=self.device)\n        elif self.sampler_name == 'k_dpm_2_a':\n            self.sampler = KSampler(\n                self.model, 'dpm_2_ancestral', device=self.device\n            )\n        elif self.sampler_name == 'k_dpm_2':\n            self.sampler = KSampler(self.model, 'dpm_2', device=self.device)\n        elif self.sampler_name == 'k_euler_a':\n            self.sampler = KSampler(\n                self.model, 'euler_ancestral', device=self.device\n            )\n        elif self.sampler_name == 'k_euler':\n            self.sampler = KSampler(self.model, 'euler', device=self.device)\n        elif self.sampler_name == 'k_heun':\n            self.sampler = KSampler(self.model, 'heun', device=self.device)\n        elif self.sampler_name == 'k_lms':\n            self.sampler = KSampler(self.model, 'lms', device=self.device)\n        else:\n            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'\n            self.sampler = PLMSSampler(self.model, device=self.device)\n\n        print(msg)\n\n    def _load_model_from_config(self, config, ckpt):\n        print(f'>> Loading model from {ckpt}')\n\n        # for usage statistics\n        device_type = choose_torch_device()\n        if device_type == 'cuda':\n            torch.cuda.reset_peak_memory_stats() \n        tic = time.time()\n\n        # this does the work\n        # NOTE: torch.load() unpickles the checkpoint, so only load weights from trusted sources\n        pl_sd = torch.load(ckpt, map_location='cpu')\n        sd = pl_sd['state_dict']\n        model = instantiate_from_config(config.model)\n        m, u = model.load_state_dict(sd, strict=False)\n        \n        if self.full_precision:\n            print(\n                '>> Using slower but more accurate full-precision math (--full_precision)'\n            )\n        else:\n            print(\n                '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'\n            )\n            model.half()\n        model.to(self.device)\n        model.eval()\n\n        # usage statistics\n        toc = time.time()\n        print(\n            f'>> Model loaded in', '%4.2fs' % (toc - tic)\n        )\n        if device_type == 'cuda':\n            print(\n                '>> Max VRAM used to load the model:',\n                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),\n                '\\n>> Current VRAM usage:'\n                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),\n            )\n\n        return model\n\n    def _load_img(self, path, width, height, fit=False):\n        assert os.path.exists(path), f'>> {path}: File not found'\n\n        #        with Image.open(path) as img:\n        #            image = img.convert('RGBA')\n        image = Image.open(path)\n        print(\n            f'>> loaded input image of size {image.width}x{image.height} from {path}'\n        )\n        if fit:\n            image = self._fit_image(image,(width,height))\n        else:\n            image = self._squeeze_image(image)\n        return image\n\n    def _create_init_image(self,image):\n        image = image.convert('RGB')\n        # print(\n        #     f'>> DEBUG: writing the image to img.png'\n        # )\n        # image.save('img.png')\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image[None].transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image)\n        image = 2.0 * image - 1.0 \n        return image.to(self.device)\n\n    def _create_init_mask(self, image):\n        # convert into a black/white mask\n        image = self._image_to_mask(image)\n        image = image.convert('RGB')\n        # BUG: We need to use the model's downsample factor rather than hardcoding \"8\"\n        from src.stablediffusion.ldm.dream.generator.base import downsampling\n        image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS)\n        # print(\n        #     f'>> DEBUG: writing the mask to mask.png'\n        #     )\n        # image.save('mask.png')\n        image = np.array(image)\n        image = image.astype(np.float32) / 255.0\n        image = image[None].transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image)\n        return image.to(self.device)\n\n    # The mask is expected to have the region to be inpainted\n    # with alpha transparency. It converts it into a black/white\n    # image with the transparent part black.\n    def _image_to_mask(self, mask_image, invert=False) -> Image:\n        # Obtain the mask from the transparency channel\n        mask = Image.new(mode=\"L\", size=mask_image.size, color=255)\n        mask.putdata(mask_image.getdata(band=3))\n        if invert:\n            mask = ImageOps.invert(mask)\n        return mask\n\n    def _has_transparency(self,image):\n        if image.info.get(\"transparency\", None) is not None:\n            return True\n        if image.mode == \"P\":\n            transparent = image.info.get(\"transparency\", -1)\n            for _, index in image.getcolors():\n                if index == transparent:\n                    return True\n        elif image.mode == \"RGBA\":\n            extrema = image.getextrema()\n            if extrema[3][0] < 255:\n                return True\n        return False\n\n    \n    def _check_for_erasure(self,image):\n        width, height = image.size\n        pixdata       = image.load()\n        colored       = 0\n        for y in range(height):\n            for x in range(width):\n                if pixdata[x, y][3] == 0:\n                    r, g, b, _ = pixdata[x, y]\n                    if (r, g, b) != (0, 0, 0) and \\\n                       (r, g, b) != (255, 255, 255):\n                        colored += 1\n        return colored == 0\n\n    def _squeeze_image(self,image):\n        x,y,resize_needed = self._resolution_check(image.width,image.height)\n        if resize_needed:\n            return InitImageResizer(image).resize(x,y)\n        return image\n\n\n    def _fit_image(self,image,max_dimensions):\n        w,h = max_dimensions\n        print(\n            f'>> image will be resized to fit inside a box {w}x{h} in size.'\n        )\n        if image.width > image.height:\n            h   = None   # by setting h to none, we tell InitImageResizer to fit into the width and calculate height\n        elif image.height > image.width:\n            w   = None   # ditto for w\n        else:\n            pass\n        image = InitImageResizer(image).resize(w,h)   # note that InitImageResizer does the multiple of 64 truncation internally\n        print(\n            f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'\n            )\n        return image\n\n    def _resolution_check(self, width, height, log=False):\n        resize_needed = False\n        w, h = map(\n            lambda x: x - x % 64, (width, height)\n        )  # resize to integer multiple of 64\n        if h != height or w != width:\n            if log:\n                print(\n                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'\n                )\n            height = h\n            width  = w\n            resize_needed = True\n\n        if (width * height) > (self.width * self.height):\n            print(\">> This input is larger than your defaults. If you run out of memory, please use a smaller image.\")\n\n        return width, height, resize_needed\n\n\n"
  },
  {
    "path": "src/stablediffusion/ldm/gfpgan/gfpgan_tools.py",
    "content": "import torch\nimport warnings\nimport os\nimport sys\nimport numpy as np\n\nfrom PIL import Image\nfrom scripts.dream import create_argv_parser\n\n# The GFPGAN / Real-ESRGAN settings come from the command line of the running\n# script and are parsed once, at import time.\narg_parser = create_argv_parser()\nopt        = arg_parser.parse_args()\nmodel_path          = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)\ngfpgan_model_exists = os.path.isfile(model_path)\n\ndef run_gfpgan(image, strength, seed, upsampler_scale=4):\n    \"\"\"Restore faces in a PIL image with GFPGAN.\n\n    strength        -- 0.0-1.0 weight of the restored image when blending it\n                       with the original (1.0 returns the restored image as is)\n    seed            -- only used in log messages\n    upsampler_scale -- upscale factor passed to GFPGANer and to the optional\n                       Real-ESRGAN background upsampler\n\n    Returns the restored image, or the input image unchanged if GFPGAN could\n    not be initialized.\n    \"\"\"\n    print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')\n    gfpgan = None\n    with warnings.catch_warnings():\n        warnings.filterwarnings('ignore', category=DeprecationWarning)\n        warnings.filterwarnings('ignore', category=UserWarning)\n\n        try:\n            if not gfpgan_model_exists:\n                raise Exception('GFPGAN model not found at path ' + model_path)\n\n            # extend sys.path once rather than on every call\n            gfpgan_dir = os.path.abspath(opt.gfpgan_dir)\n            if gfpgan_dir not in sys.path:\n                sys.path.append(gfpgan_dir)\n            from gfpgan import GFPGANer\n\n            bg_upsampler = _load_gfpgan_bg_upsampler(\n                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile\n            )\n\n            gfpgan = GFPGANer(\n                model_path=model_path,\n                upscale=upsampler_scale,\n                arch='clean',\n                channel_multiplier=2,\n                bg_upsampler=bg_upsampler,\n            )\n        except Exception:\n            import traceback\n\n            print('>> Error loading GFPGAN:', file=sys.stderr)\n            print(traceback.format_exc(), file=sys.stderr)\n\n    if gfpgan is None:\n        print(\n            f'>> WARNING: GFPGAN not initialized.'\n        )\n        print(\n            f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \\nor change GFPGAN directory with --gfpgan_dir.'\n        )\n        return image\n\n    image = image.convert('RGB')\n\n    _, _, restored_img = gfpgan.enhance(\n        np.array(image, dtype=np.uint8),\n        has_aligned=False,\n        only_center_face=False,\n        paste_back=True,\n    )\n    res = Image.fromarray(restored_img)\n\n    # drop the model references before emptying the CUDA cache; memory that is\n    # still referenced cannot be released\n    del gfpgan, bg_upsampler\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return _blend_with_original(image, res, strength)\n\n\ndef _blend_with_original(original, result, strength):\n    \"\"\"Blend `result` over `original` with weight `strength`.\n\n    strength >= 1.0 returns `result` unchanged. Otherwise `original` is first\n    resized and converted to match `result`, because Image.blend() requires\n    both images to have the same size and mode.\n    \"\"\"\n    if strength >= 1.0:\n        return result\n    if original.size != result.size:\n        original = original.resize(result.size)\n    if original.mode != result.mode:\n        original = original.convert(result.mode)\n    return Image.blend(original, result, strength)\n\n\ndef _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):\n    \"\"\"Create a Real-ESRGAN upsampler.\n\n    Returns a RealESRGANer when bg_upsampler is 'realesrgan' and\n    upsampler_scale is 2 or 4, otherwise None.\n    \"\"\"\n    if bg_upsampler != 'realesrgan':\n        return None\n\n    # pretrained weights for each supported scale\n    model_urls = {\n        2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',\n        4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',\n    }\n    if upsampler_scale not in model_urls:\n        return None\n\n    from basicsr.archs.rrdbnet_arch import RRDBNet\n    from realesrgan import RealESRGANer\n\n    # half precision only on CUDA; CPU and MPS (M1) run in full precision\n    use_half_precision = torch.cuda.is_available()\n\n    # the x2 and x4 models share one architecture and differ only in scale\n    model = RRDBNet(\n        num_in_ch=3,\n        num_out_ch=3,\n        num_feat=64,\n        num_block=23,\n        num_grow_ch=32,\n        scale=int(upsampler_scale),\n    )\n\n    return RealESRGANer(\n        scale=upsampler_scale,\n        model_path=model_urls[upsampler_scale],\n        model=model,\n        tile=bg_tile,\n        tile_pad=10,\n        pre_pad=0,\n        half=use_half_precision,\n    )\n\n\ndef real_esrgan_upscale(image, strength, upsampler_scale, seed):\n    \"\"\"Upscale a PIL image with Real-ESRGAN.\n\n    strength        -- 0.0-1.0 weight of the upscaled image when blending it\n                       with the (resized) original\n    upsampler_scale -- upscale factor; pretrained models exist for 2 and 4\n    seed            -- only used in log messages\n\n    Returns the upscaled image, or the input image unchanged if no upsampler\n    could be created.\n    \"\"\"\n    print(\n        f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'\n    )\n\n    upsampler = None\n    with warnings.catch_warnings():\n        warnings.filterwarnings('ignore', category=DeprecationWarning)\n        warnings.filterwarnings('ignore', category=UserWarning)\n\n        try:\n            upsampler = _load_gfpgan_bg_upsampler(\n                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile\n            )\n        except Exception:\n            import traceback\n\n            print('>> Error loading Real-ESRGAN:', file=sys.stderr)\n            print(traceback.format_exc(), file=sys.stderr)\n\n    if upsampler is None:\n        # loading failed, or the upsampler / scale combination is not supported\n        print(\n            '>> WARNING: Real-ESRGAN not initialized (requires --gfpgan_bg_upsampler realesrgan and a scale of 2 or 4). The image was not upscaled.'\n        )\n        return image\n\n    output, _ = upsampler.enhance(\n        np.array(image, dtype=np.uint8),\n        outscale=upsampler_scale,\n        alpha_upsampler=opt.gfpgan_bg_upsampler,\n    )\n    res = Image.fromarray(output)\n\n    del upsampler\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return _blend_with_original(image, res, strength)\n"
  },
  {
    "path": "src/stablediffusion/ldm/lr_scheduler.py",
    "content": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n\n    def __init__(\n        self,\n        warm_up_steps,\n        lr_min,\n        lr_max,\n        lr_start,\n        max_decay_steps,\n        verbosity_interval=0,\n    ):\n        self.lr_warm_up_steps = warm_up_steps\n        self.lr_start = lr_start\n        self.lr_min = lr_min\n        self.lr_max = lr_max\n        self.lr_max_decay_steps = max_decay_steps\n        self.last_lr = 0.0\n        self.verbosity_interval = verbosity_interval\n\n    def schedule(self, n, **kwargs):\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(\n                    f'current step: {n}, recent lr-multiplier: {self.last_lr}'\n                )\n        if n < self.lr_warm_up_steps:\n            lr = (\n                self.lr_max - self.lr_start\n            ) / self.lr_warm_up_steps * n + self.lr_start\n            self.last_lr = lr\n            return lr\n        else:\n            t = (n - self.lr_warm_up_steps) / (\n                self.lr_max_decay_steps - self.lr_warm_up_steps\n            )\n            t = min(t, 1.0)\n            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (\n                1 + np.cos(t * np.pi)\n            )\n            self.last_lr = lr\n            return lr\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n\n\nclass LambdaWarmUpCosineScheduler2:\n    \"\"\"\n    supports repeated iterations, configurable via lists\n    note: use with a base_lr of 1.0.\n    \"\"\"\n\n    def __init__(\n        self,\n        warm_up_steps,\n        f_min,\n        f_max,\n        f_start,\n        cycle_lengths,\n        verbosity_interval=0,\n    ):\n        assert (\n            len(warm_up_steps)\n            == len(f_min)\n            == len(f_max)\n            == len(f_start)\n            == len(cycle_lengths)\n    
    )\n        self.lr_warm_up_steps = warm_up_steps\n        self.f_start = f_start\n        self.f_min = f_min\n        self.f_max = f_max\n        self.cycle_lengths = cycle_lengths\n        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))\n        self.last_f = 0.0\n        self.verbosity_interval = verbosity_interval\n\n    def find_in_interval(self, n):\n        interval = 0\n        for cl in self.cum_cycles[1:]:\n            if n <= cl:\n                return interval\n            interval += 1\n\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(\n                    f'current step: {n}, recent lr-multiplier: {self.last_f}, '\n                    f'current cycle {cycle}'\n                )\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (\n                self.f_max[cycle] - self.f_start[cycle]\n            ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            t = (n - self.lr_warm_up_steps[cycle]) / (\n                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]\n            )\n            t = min(t, 1.0)\n            f = self.f_min[cycle] + 0.5 * (\n                self.f_max[cycle] - self.f_min[cycle]\n            ) * (1 + np.cos(t * np.pi))\n            self.last_f = f\n            return f\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n\n\nclass LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0:\n                print(\n                    f'current step: {n}, recent lr-multiplier: {self.last_f}, '\n       
             f'current cycle {cycle}'\n                )\n\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (\n                self.f_max[cycle] - self.f_start[cycle]\n            ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (\n                self.cycle_lengths[cycle] - n\n            ) / (self.cycle_lengths[cycle])\n            self.last_f = f\n            return f\n"
  },
  {
    "path": "src/stablediffusion/ldm/models/autoencoder.py",
    "content": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\n\nfrom taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer\n\nfrom src.stablediffusion.ldm.modules.diffusionmodules.model import Encoder, Decoder\nfrom src.stablediffusion.ldm.modules.distributions.distributions import (\n    DiagonalGaussianDistribution,\n)\n\nfrom src.stablediffusion.ldm.util import instantiate_from_config\n\n\nclass VQModel(pl.LightningModule):\n    def __init__(\n        self,\n        ddconfig,\n        lossconfig,\n        n_embed,\n        embed_dim,\n        ckpt_path=None,\n        ignore_keys=[],\n        image_key='image',\n        colorize_nlabels=None,\n        monitor=None,\n        batch_resize_range=None,\n        scheduler_config=None,\n        lr_g_factor=1.0,\n        remap=None,\n        sane_index_shape=False,  # tell vector quantizer to return indices as bhw\n        use_ema=False,\n    ):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.n_embed = n_embed\n        self.image_key = image_key\n        self.encoder = Encoder(**ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        self.loss = instantiate_from_config(lossconfig)\n        self.quantize = VectorQuantizer(\n            n_embed,\n            embed_dim,\n            beta=0.25,\n            remap=remap,\n            sane_index_shape=sane_index_shape,\n        )\n        self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)\n        self.post_quant_conv = torch.nn.Conv2d(\n            embed_dim, ddconfig['z_channels'], 1\n        )\n        if colorize_nlabels is not None:\n            assert type(colorize_nlabels) == int\n            self.register_buffer(\n                'colorize', torch.randn(3, colorize_nlabels, 1, 1)\n            )\n        if monitor is not None:\n            self.monitor = monitor\n        self.batch_resize_range = batch_resize_range\n      
  if self.batch_resize_range is not None:\n            print(\n                f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'\n            )\n\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self)\n            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n        self.scheduler_config = scheduler_config\n        self.lr_g_factor = lr_g_factor\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.parameters())\n            self.model_ema.copy_to(self)\n            if context is not None:\n                print(f'{context}: Switched to EMA weights')\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.parameters())\n                if context is not None:\n                    print(f'{context}: Restored training weights')\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location='cpu')['state_dict']\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print('Deleting key {} from state_dict.'.format(k))\n                    del sd[k]\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(\n            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'\n        )\n        if len(missing) > 0:\n            print(f'Missing Keys: {missing}')\n            print(f'Unexpected Keys: {unexpected}')\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self)\n\n    def encode(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        
quant, emb_loss, info = self.quantize(h)\n        return quant, emb_loss, info\n\n    def encode_to_prequant(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        return h\n\n    def decode(self, quant):\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n        return dec\n\n    def decode_code(self, code_b):\n        quant_b = self.quantize.embed_code(code_b)\n        dec = self.decode(quant_b)\n        return dec\n\n    def forward(self, input, return_pred_indices=False):\n        quant, diff, (_, _, ind) = self.encode(input)\n        dec = self.decode(quant)\n        if return_pred_indices:\n            return dec, diff, ind\n        return dec, diff\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = (\n            x.permute(0, 3, 1, 2)\n            .to(memory_format=torch.contiguous_format)\n            .float()\n        )\n        if self.batch_resize_range is not None:\n            lower_size = self.batch_resize_range[0]\n            upper_size = self.batch_resize_range[1]\n            if self.global_step <= 4:\n                # do the first few batches with max size to avoid later oom\n                new_resize = upper_size\n            else:\n                new_resize = np.random.choice(\n                    np.arange(lower_size, upper_size + 16, 16)\n                )\n            if new_resize != x.shape[2]:\n                x = F.interpolate(x, size=new_resize, mode='bicubic')\n            x = x.detach()\n        return x\n\n    def training_step(self, batch, batch_idx, optimizer_idx):\n        # https://github.com/pytorch/pytorch/issues/37142\n        # try not to fool the heuristics\n        x = self.get_input(batch, self.image_key)\n        xrec, qloss, ind = self(x, return_pred_indices=True)\n\n        if optimizer_idx == 0:\n            # autoencode\n            aeloss, log_dict_ae = self.loss(\n              
  qloss,\n                x,\n                xrec,\n                optimizer_idx,\n                self.global_step,\n                last_layer=self.get_last_layer(),\n                split='train',\n                predicted_indices=ind,\n            )\n\n            self.log_dict(\n                log_dict_ae,\n                prog_bar=False,\n                logger=True,\n                on_step=True,\n                on_epoch=True,\n            )\n            return aeloss\n\n        if optimizer_idx == 1:\n            # discriminator\n            discloss, log_dict_disc = self.loss(\n                qloss,\n                x,\n                xrec,\n                optimizer_idx,\n                self.global_step,\n                last_layer=self.get_last_layer(),\n                split='train',\n            )\n            self.log_dict(\n                log_dict_disc,\n                prog_bar=False,\n                logger=True,\n                on_step=True,\n                on_epoch=True,\n            )\n            return discloss\n\n    def validation_step(self, batch, batch_idx):\n        log_dict = self._validation_step(batch, batch_idx)\n        with self.ema_scope():\n            log_dict_ema = self._validation_step(\n                batch, batch_idx, suffix='_ema'\n            )\n        return log_dict\n\n    def _validation_step(self, batch, batch_idx, suffix=''):\n        x = self.get_input(batch, self.image_key)\n        xrec, qloss, ind = self(x, return_pred_indices=True)\n        aeloss, log_dict_ae = self.loss(\n            qloss,\n            x,\n            xrec,\n            0,\n            self.global_step,\n            last_layer=self.get_last_layer(),\n            split='val' + suffix,\n            predicted_indices=ind,\n        )\n\n        discloss, log_dict_disc = self.loss(\n            qloss,\n            x,\n            xrec,\n            1,\n            self.global_step,\n            last_layer=self.get_last_layer(),\n        
    split='val' + suffix,\n            predicted_indices=ind,\n        )\n        rec_loss = log_dict_ae[f'val{suffix}/rec_loss']\n        self.log(\n            f'val{suffix}/rec_loss',\n            rec_loss,\n            prog_bar=True,\n            logger=True,\n            on_step=False,\n            on_epoch=True,\n            sync_dist=True,\n        )\n        self.log(\n            f'val{suffix}/aeloss',\n            aeloss,\n            prog_bar=True,\n            logger=True,\n            on_step=False,\n            on_epoch=True,\n            sync_dist=True,\n        )\n        if version.parse(pl.__version__) >= version.parse('1.4.0'):\n            del log_dict_ae[f'val{suffix}/rec_loss']\n        self.log_dict(log_dict_ae)\n        self.log_dict(log_dict_disc)\n        return self.log_dict\n\n    def configure_optimizers(self):\n        lr_d = self.learning_rate\n        lr_g = self.lr_g_factor * self.learning_rate\n        print('lr_d', lr_d)\n        print('lr_g', lr_g)\n        opt_ae = torch.optim.Adam(\n            list(self.encoder.parameters())\n            + list(self.decoder.parameters())\n            + list(self.quantize.parameters())\n            + list(self.quant_conv.parameters())\n            + list(self.post_quant_conv.parameters()),\n            lr=lr_g,\n            betas=(0.5, 0.9),\n        )\n        opt_disc = torch.optim.Adam(\n            self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)\n        )\n\n        if self.scheduler_config is not None:\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print('Setting up LambdaLR scheduler...')\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(\n                        opt_ae, lr_lambda=scheduler.schedule\n                    ),\n                    'interval': 'step',\n                    'frequency': 1,\n                },\n                {\n                    'scheduler': LambdaLR(\n             
           opt_disc, lr_lambda=scheduler.schedule\n                    ),\n                    'interval': 'step',\n                    'frequency': 1,\n                },\n            ]\n            return [opt_ae, opt_disc], scheduler\n        return [opt_ae, opt_disc], []\n\n    def get_last_layer(self):\n        return self.decoder.conv_out.weight\n\n    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):\n        log = dict()\n        x = self.get_input(batch, self.image_key)\n        x = x.to(self.device)\n        if only_inputs:\n            log['inputs'] = x\n            return log\n        xrec, _ = self(x)\n        if x.shape[1] > 3:\n            # colorize with random projection\n            assert xrec.shape[1] > 3\n            x = self.to_rgb(x)\n            xrec = self.to_rgb(xrec)\n        log['inputs'] = x\n        log['reconstructions'] = xrec\n        if plot_ema:\n            with self.ema_scope():\n                xrec_ema, _ = self(x)\n                if x.shape[1] > 3:\n                    xrec_ema = self.to_rgb(xrec_ema)\n                log['reconstructions_ema'] = xrec_ema\n        return log\n\n    def to_rgb(self, x):\n        assert self.image_key == 'segmentation'\n        if not hasattr(self, 'colorize'):\n            self.register_buffer(\n                'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)\n            )\n        x = F.conv2d(x, weight=self.colorize)\n        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0\n        return x\n\n\nclass VQModelInterface(VQModel):\n    def __init__(self, embed_dim, *args, **kwargs):\n        super().__init__(embed_dim=embed_dim, *args, **kwargs)\n        self.embed_dim = embed_dim\n\n    def encode(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        return h\n\n    def decode(self, h, force_not_quantize=False):\n        # also go through quantization layer\n        if not force_not_quantize:\n            quant, emb_loss, info = 
self.quantize(h)\n        else:\n            quant = h\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n        return dec\n\n\nclass AutoencoderKL(pl.LightningModule):\n    def __init__(\n        self,\n        ddconfig,\n        lossconfig,\n        embed_dim,\n        ckpt_path=None,\n        ignore_keys=[],\n        image_key='image',\n        colorize_nlabels=None,\n        monitor=None,\n    ):\n        super().__init__()\n        self.image_key = image_key\n        self.encoder = Encoder(**ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        self.loss = instantiate_from_config(lossconfig)\n        assert ddconfig['double_z']\n        self.quant_conv = torch.nn.Conv2d(\n            2 * ddconfig['z_channels'], 2 * embed_dim, 1\n        )\n        self.post_quant_conv = torch.nn.Conv2d(\n            embed_dim, ddconfig['z_channels'], 1\n        )\n        self.embed_dim = embed_dim\n        if colorize_nlabels is not None:\n            assert type(colorize_nlabels) == int\n            self.register_buffer(\n                'colorize', torch.randn(3, colorize_nlabels, 1, 1)\n            )\n        if monitor is not None:\n            self.monitor = monitor\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location='cpu')['state_dict']\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print('Deleting key {} from state_dict.'.format(k))\n                    del sd[k]\n        self.load_state_dict(sd, strict=False)\n        print(f'Restored from {path}')\n\n    def encode(self, x):\n        h = self.encoder(x)\n        moments = self.quant_conv(h)\n        posterior = DiagonalGaussianDistribution(moments)\n        return posterior\n\n    def decode(self, z):\n        z = 
self.post_quant_conv(z)\n        dec = self.decoder(z)\n        return dec\n\n    def forward(self, input, sample_posterior=True):\n        posterior = self.encode(input)\n        if sample_posterior:\n            z = posterior.sample()\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        return dec, posterior\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = (\n            x.permute(0, 3, 1, 2)\n            .to(memory_format=torch.contiguous_format)\n            .float()\n        )\n        return x\n\n    def training_step(self, batch, batch_idx, optimizer_idx):\n        inputs = self.get_input(batch, self.image_key)\n        reconstructions, posterior = self(inputs)\n\n        if optimizer_idx == 0:\n            # train encoder+decoder+logvar\n            aeloss, log_dict_ae = self.loss(\n                inputs,\n                reconstructions,\n                posterior,\n                optimizer_idx,\n                self.global_step,\n                last_layer=self.get_last_layer(),\n                split='train',\n            )\n            self.log(\n                'aeloss',\n                aeloss,\n                prog_bar=True,\n                logger=True,\n                on_step=True,\n                on_epoch=True,\n            )\n            self.log_dict(\n                log_dict_ae,\n                prog_bar=False,\n                logger=True,\n                on_step=True,\n                on_epoch=False,\n            )\n            return aeloss\n\n        if optimizer_idx == 1:\n            # train the discriminator\n            discloss, log_dict_disc = self.loss(\n                inputs,\n                reconstructions,\n                posterior,\n                optimizer_idx,\n                self.global_step,\n                last_layer=self.get_last_layer(),\n                split='train',\n            )\n\n       
     self.log(\n                'discloss',\n                discloss,\n                prog_bar=True,\n                logger=True,\n                on_step=True,\n                on_epoch=True,\n            )\n            self.log_dict(\n                log_dict_disc,\n                prog_bar=False,\n                logger=True,\n                on_step=True,\n                on_epoch=False,\n            )\n            return discloss\n\n    def validation_step(self, batch, batch_idx):\n        inputs = self.get_input(batch, self.image_key)\n        reconstructions, posterior = self(inputs)\n        aeloss, log_dict_ae = self.loss(\n            inputs,\n            reconstructions,\n            posterior,\n            0,\n            self.global_step,\n            last_layer=self.get_last_layer(),\n            split='val',\n        )\n\n        discloss, log_dict_disc = self.loss(\n            inputs,\n            reconstructions,\n            posterior,\n            1,\n            self.global_step,\n            last_layer=self.get_last_layer(),\n            split='val',\n        )\n\n        self.log('val/rec_loss', log_dict_ae['val/rec_loss'])\n        self.log_dict(log_dict_ae)\n        self.log_dict(log_dict_disc)\n        return self.log_dict\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        opt_ae = torch.optim.Adam(\n            list(self.encoder.parameters())\n            + list(self.decoder.parameters())\n            + list(self.quant_conv.parameters())\n            + list(self.post_quant_conv.parameters()),\n            lr=lr,\n            betas=(0.5, 0.9),\n        )\n        opt_disc = torch.optim.Adam(\n            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)\n        )\n        return [opt_ae, opt_disc], []\n\n    def get_last_layer(self):\n        return self.decoder.conv_out.weight\n\n    @torch.no_grad()\n    def log_images(self, batch, only_inputs=False, **kwargs):\n        log = dict()\n        
x = self.get_input(batch, self.image_key)\n        x = x.to(self.device)\n        if not only_inputs:\n            xrec, posterior = self(x)\n            if x.shape[1] > 3:\n                # colorize with random projection\n                assert xrec.shape[1] > 3\n                x = self.to_rgb(x)\n                xrec = self.to_rgb(xrec)\n            log['samples'] = self.decode(torch.randn_like(posterior.sample()))\n            log['reconstructions'] = xrec\n        log['inputs'] = x\n        return log\n\n    def to_rgb(self, x):\n        assert self.image_key == 'segmentation'\n        if not hasattr(self, 'colorize'):\n            self.register_buffer(\n                'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)\n            )\n        x = F.conv2d(x, weight=self.colorize)\n        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0\n        return x\n\n\nclass IdentityFirstStage(torch.nn.Module):\n    def __init__(self, *args, vq_interface=False, **kwargs):\n        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff\n        super().__init__()\n\n    def encode(self, x, *args, **kwargs):\n        return x\n\n    def decode(self, x, *args, **kwargs):\n        return x\n\n    def quantize(self, x, *args, **kwargs):\n        if self.vq_interface:\n            return x, None, [None, None, None]\n        return x\n\n    def forward(self, x, *args, **kwargs):\n        return x\n"
  },
  {
    "path": "src/stablediffusion/ldm/models/diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "src/stablediffusion/ldm/models/diffusion/classifier.py",
    "content": "import os\nimport torch\nimport pytorch_lightning as pl\nfrom omegaconf import OmegaConf\nfrom torch.nn import functional as F\nfrom torch.optim import AdamW\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom copy import deepcopy\nfrom einops import rearrange\nfrom glob import glob\nfrom natsort import natsorted\n\nfrom src.stablediffusion.ldm.modules.diffusionmodules.openaimodel import (\n    EncoderUNetModel,\n    UNetModel,\n)\nfrom src.stablediffusion.ldm.util import log_txt_as_img, default, ismap, instantiate_from_config\n\n__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\nclass NoisyLatentImageClassifier(pl.LightningModule):\n    def __init__(\n        self,\n        diffusion_path,\n        num_classes,\n        ckpt_path=None,\n        pool='attention',\n        label_key=None,\n        diffusion_ckpt_path=None,\n        scheduler_config=None,\n        weight_decay=1.0e-2,\n        log_steps=10,\n        monitor='val/loss',\n        *args,\n        **kwargs,\n    ):\n        super().__init__(*args, **kwargs)\n        self.num_classes = num_classes\n        # get latest config of diffusion model\n        diffusion_config = natsorted(\n            glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))\n        )[-1]\n        self.diffusion_config = OmegaConf.load(diffusion_config).model\n        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path\n        self.load_diffusion()\n\n        self.monitor = monitor\n        self.numd = (\n            self.diffusion_model.first_stage_model.encoder.num_resolutions - 1\n        )\n        self.log_time_interval = (\n            self.diffusion_model.num_timesteps // log_steps\n        )\n        self.log_steps = log_steps\n\n        self.label_key = (\n            label_key\n       
     if not hasattr(self.diffusion_model, 'cond_stage_key')\n            else self.diffusion_model.cond_stage_key\n        )\n\n        assert (\n            self.label_key is not None\n        ), 'label_key neither in diffusion model nor in model.params'\n\n        if self.label_key not in __models__:\n            raise NotImplementedError()\n\n        self.load_classifier(ckpt_path, pool)\n\n        self.scheduler_config = scheduler_config\n        self.use_scheduler = self.scheduler_config is not None\n        self.weight_decay = weight_decay\n\n    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):\n        sd = torch.load(path, map_location='cpu')\n        if 'state_dict' in list(sd.keys()):\n            sd = sd['state_dict']\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print('Deleting key {} from state_dict.'.format(k))\n                    del sd[k]\n        missing, unexpected = (\n            self.load_state_dict(sd, strict=False)\n            if not only_model\n            else self.model.load_state_dict(sd, strict=False)\n        )\n        print(\n            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'\n        )\n        if len(missing) > 0:\n            print(f'Missing Keys: {missing}')\n        if len(unexpected) > 0:\n            print(f'Unexpected Keys: {unexpected}')\n\n    def load_diffusion(self):\n        model = instantiate_from_config(self.diffusion_config)\n        self.diffusion_model = model.eval()\n        self.diffusion_model.train = disabled_train\n        for param in self.diffusion_model.parameters():\n            param.requires_grad = False\n\n    def load_classifier(self, ckpt_path, pool):\n        model_config = deepcopy(\n            self.diffusion_config.params.unet_config.params\n        )\n        model_config.in_channels = (\n            
self.diffusion_config.params.unet_config.params.out_channels\n        )\n        model_config.out_channels = self.num_classes\n        if self.label_key == 'class_label':\n            model_config.pool = pool\n\n        self.model = __models__[self.label_key](**model_config)\n        if ckpt_path is not None:\n            print(\n                '#####################################################################'\n            )\n            print(f'load from ckpt \"{ckpt_path}\"')\n            print(\n                '#####################################################################'\n            )\n            self.init_from_ckpt(ckpt_path)\n\n    @torch.no_grad()\n    def get_x_noisy(self, x, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x))\n        continuous_sqrt_alpha_cumprod = None\n        if self.diffusion_model.use_continuous_noise:\n            continuous_sqrt_alpha_cumprod = (\n                self.diffusion_model.sample_continuous_noise_level(\n                    x.shape[0], t + 1\n                )\n            )\n            # todo: make sure t+1 is correct here\n\n        return self.diffusion_model.q_sample(\n            x_start=x,\n            t=t,\n            noise=noise,\n            continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,\n        )\n\n    def forward(self, x_noisy, t, *args, **kwargs):\n        return self.model(x_noisy, t)\n\n    @torch.no_grad()\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = rearrange(x, 'b h w c -> b c h w')\n        x = x.to(memory_format=torch.contiguous_format).float()\n        return x\n\n    @torch.no_grad()\n    def get_conditioning(self, batch, k=None):\n        if k is None:\n            k = self.label_key\n        assert k is not None, 'Needs to provide label key'\n\n        targets = batch[k].to(self.device)\n\n        if self.label_key == 'segmentation':\n            
targets = rearrange(targets, 'b h w c -> b c h w')\n            for down in range(self.numd):\n                h, w = targets.shape[-2:]\n                targets = F.interpolate(\n                    targets, size=(h // 2, w // 2), mode='nearest'\n                )\n\n            # targets = rearrange(targets,'b c h w -> b h w c')\n\n        return targets\n\n    def compute_top_k(self, logits, labels, k, reduction='mean'):\n        _, top_ks = torch.topk(logits, k, dim=1)\n        if reduction == 'mean':\n            return (\n                (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()\n            )\n        elif reduction == 'none':\n            return (top_ks == labels[:, None]).float().sum(dim=-1)\n\n    def on_train_epoch_start(self):\n        # save some memory\n        self.diffusion_model.model.to('cpu')\n\n    @torch.no_grad()\n    def write_logs(self, loss, logits, targets):\n        log_prefix = 'train' if self.training else 'val'\n        log = {}\n        log[f'{log_prefix}/loss'] = loss.mean()\n        log[f'{log_prefix}/acc@1'] = self.compute_top_k(\n            logits, targets, k=1, reduction='mean'\n        )\n        log[f'{log_prefix}/acc@5'] = self.compute_top_k(\n            logits, targets, k=5, reduction='mean'\n        )\n\n        self.log_dict(\n            log,\n            prog_bar=False,\n            logger=True,\n            on_step=self.training,\n            on_epoch=True,\n        )\n        self.log(\n            'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False\n        )\n        self.log(\n            'global_step',\n            self.global_step,\n            logger=False,\n            on_epoch=False,\n            prog_bar=True,\n        )\n        lr = self.optimizers().param_groups[0]['lr']\n        self.log(\n            'lr_abs',\n            lr,\n            on_step=True,\n            logger=True,\n            on_epoch=False,\n            prog_bar=True,\n        )\n\n    def shared_step(self, 
batch, t=None):\n        x, *_ = self.diffusion_model.get_input(\n            batch, k=self.diffusion_model.first_stage_key\n        )\n        targets = self.get_conditioning(batch)\n        if targets.dim() == 4:\n            targets = targets.argmax(dim=1)\n        if t is None:\n            t = torch.randint(\n                0,\n                self.diffusion_model.num_timesteps,\n                (x.shape[0],),\n                device=self.device,\n            ).long()\n        else:\n            t = torch.full(\n                size=(x.shape[0],), fill_value=t, device=self.device\n            ).long()\n        x_noisy = self.get_x_noisy(x, t)\n        logits = self(x_noisy, t)\n\n        loss = F.cross_entropy(logits, targets, reduction='none')\n\n        self.write_logs(loss.detach(), logits.detach(), targets.detach())\n\n        loss = loss.mean()\n        return loss, logits, x_noisy, targets\n\n    def training_step(self, batch, batch_idx):\n        loss, *_ = self.shared_step(batch)\n        return loss\n\n    def reset_noise_accs(self):\n        self.noisy_acc = {\n            t: {'acc@1': [], 'acc@5': []}\n            for t in range(\n                0,\n                self.diffusion_model.num_timesteps,\n                self.diffusion_model.log_every_t,\n            )\n        }\n\n    def on_validation_start(self):\n        self.reset_noise_accs()\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        loss, *_ = self.shared_step(batch)\n\n        for t in self.noisy_acc:\n            _, logits, _, targets = self.shared_step(batch, t)\n            self.noisy_acc[t]['acc@1'].append(\n                self.compute_top_k(logits, targets, k=1, reduction='mean')\n            )\n            self.noisy_acc[t]['acc@5'].append(\n                self.compute_top_k(logits, targets, k=5, reduction='mean')\n            )\n\n        return loss\n\n    def configure_optimizers(self):\n        optimizer = AdamW(\n            
self.model.parameters(),\n            lr=self.learning_rate,\n            weight_decay=self.weight_decay,\n        )\n\n        if self.use_scheduler:\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print('Setting up LambdaLR scheduler...')\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(\n                        optimizer, lr_lambda=scheduler.schedule\n                    ),\n                    'interval': 'step',\n                    'frequency': 1,\n                }\n            ]\n            return [optimizer], scheduler\n\n        return optimizer\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, *args, **kwargs):\n        log = dict()\n        x = self.get_input(batch, self.diffusion_model.first_stage_key)\n        log['inputs'] = x\n\n        y = self.get_conditioning(batch)\n\n        if self.label_key == 'class_label':\n            y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])\n            log['labels'] = y\n\n        if ismap(y):\n            log['labels'] = self.diffusion_model.to_rgb(y)\n\n            for step in range(self.log_steps):\n                current_time = step * self.log_time_interval\n\n                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)\n\n                log[f'inputs@t{current_time}'] = x_noisy\n\n                pred = F.one_hot(\n                    logits.argmax(dim=1), num_classes=self.num_classes\n                )\n                pred = rearrange(pred, 'b h w c -> b c h w')\n\n                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(\n                    pred\n                )\n\n        for key in log:\n            log[key] = log[key][:N]\n\n        return log\n"
  },
  {
    "path": "src/stablediffusion/ldm/models/diffusion/ddim.py",
    "content": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\nfrom src.stablediffusion.ldm.dream.devices import choose_torch_device\n\nfrom src.stablediffusion.ldm.modules.diffusionmodules.util import (\n    make_ddim_sampling_parameters,\n    make_ddim_timesteps,\n    noise_like,\n    extract_into_tensor,\n)\n\n\nclass DDIMSampler(object):\n    def __init__(self, model, schedule='linear', device=None, **kwargs):\n        super().__init__()\n        self.model = model\n        self.ddpm_num_timesteps = model.num_timesteps\n        self.schedule = schedule\n        self.device   = device or choose_torch_device()\n\n    def register_buffer(self, name, attr):\n        if type(attr) == torch.Tensor:\n            if attr.device != torch.device(self.device):\n                attr = attr.to(dtype=torch.float32, device=self.device)\n        setattr(self, name, attr)\n\n    def make_schedule(\n        self,\n        ddim_num_steps,\n        ddim_discretize='uniform',\n        ddim_eta=0.0,\n        verbose=True,\n    ):\n        self.ddim_timesteps = make_ddim_timesteps(\n            ddim_discr_method=ddim_discretize,\n            num_ddim_timesteps=ddim_num_steps,\n            num_ddpm_timesteps=self.ddpm_num_timesteps,\n            verbose=verbose,\n        )\n        alphas_cumprod = self.model.alphas_cumprod\n        assert (\n            alphas_cumprod.shape[0] == self.ddpm_num_timesteps\n        ), 'alphas have to be defined for each timestep'\n        to_torch = (\n            lambda x: x.clone()\n            .detach()\n            .to(torch.float32)\n            .to(self.model.device)\n        )\n\n        self.register_buffer('betas', to_torch(self.model.betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer(\n            'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)\n        )\n\n        # calculations for diffusion 
q(x_t | x_{t-1}) and others\n        self.register_buffer(\n            'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))\n        )\n        self.register_buffer(\n            'sqrt_one_minus_alphas_cumprod',\n            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            'log_one_minus_alphas_cumprod',\n            to_torch(np.log(1.0 - alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            'sqrt_recip_alphas_cumprod',\n            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            'sqrt_recipm1_alphas_cumprod',\n            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),\n        )\n\n        # ddim sampling parameters\n        (\n            ddim_sigmas,\n            ddim_alphas,\n            ddim_alphas_prev,\n        ) = make_ddim_sampling_parameters(\n            alphacums=alphas_cumprod.cpu(),\n            ddim_timesteps=self.ddim_timesteps,\n            eta=ddim_eta,\n            verbose=verbose,\n        )\n        self.register_buffer('ddim_sigmas', ddim_sigmas)\n        self.register_buffer('ddim_alphas', ddim_alphas)\n        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n        self.register_buffer(\n            'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)\n        )\n        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n            (1 - self.alphas_cumprod_prev)\n            / (1 - self.alphas_cumprod)\n            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)\n        )\n        self.register_buffer(\n            'ddim_sigmas_for_original_num_steps',\n            sigmas_for_original_sampling_steps,\n        )\n\n    @torch.no_grad()\n    def sample(\n        self,\n        S,\n        batch_size,\n        shape,\n        conditioning=None,\n        callback=None,\n        normals_sequence=None,\n        img_callback=None,\n        quantize_x0=False,\n        
eta=0.0,\n        mask=None,\n        x0=None,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        verbose=True,\n        x_T=None,\n        log_every_t=100,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n        **kwargs,\n    ):\n        if conditioning is not None:\n            if isinstance(conditioning, dict):\n                cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n                if cbs != batch_size:\n                    print(\n                        f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'\n                    )\n            else:\n                if conditioning.shape[0] != batch_size:\n                    print(\n                        f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'\n                    )\n\n        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n        # sampling\n        C, H, W = shape\n        size = (batch_size, C, H, W)\n        print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n\n        samples, intermediates = self.ddim_sampling(\n            conditioning,\n            size,\n            callback=callback,\n            img_callback=img_callback,\n            quantize_denoised=quantize_x0,\n            mask=mask,\n            x0=x0,\n            ddim_use_original_steps=False,\n            noise_dropout=noise_dropout,\n            temperature=temperature,\n            score_corrector=score_corrector,\n            corrector_kwargs=corrector_kwargs,\n            x_T=x_T,\n            log_every_t=log_every_t,\n            unconditional_guidance_scale=unconditional_guidance_scale,\n            unconditional_conditioning=unconditional_conditioning,\n        )\n        return samples, intermediates\n\n    # This 
routine gets called from img2img\n    @torch.no_grad()\n    def ddim_sampling(\n        self,\n        cond,\n        shape,\n        x_T=None,\n        ddim_use_original_steps=False,\n        callback=None,\n        timesteps=None,\n        quantize_denoised=False,\n        mask=None,\n        x0=None,\n        img_callback=None,\n        log_every_t=100,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n    ):\n        device = self.model.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        if timesteps is None:\n            timesteps = (\n                self.ddpm_num_timesteps\n                if ddim_use_original_steps\n                else self.ddim_timesteps\n            )\n        elif timesteps is not None and not ddim_use_original_steps:\n            subset_end = (\n                int(\n                    min(timesteps / self.ddim_timesteps.shape[0], 1)\n                    * self.ddim_timesteps.shape[0]\n                )\n                - 1\n            )\n            timesteps = self.ddim_timesteps[:subset_end]\n\n        intermediates = {'x_inter': [img], 'pred_x0': [img]}\n        time_range = (\n            reversed(range(0, timesteps))\n            if ddim_use_original_steps\n            else np.flip(timesteps)\n        )\n        total_steps = (\n            timesteps if ddim_use_original_steps else timesteps.shape[0]\n        )\n        print(f'Running DDIM Sampling with {total_steps} timesteps')\n\n        iterator = tqdm(\n            time_range,\n            desc='DDIM Sampler',\n            total=total_steps,\n            dynamic_ncols=True,\n        )\n\n        for i, step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full((b,), step, 
device=device, dtype=torch.long)\n\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.model.q_sample(\n                    x0, ts\n                )  # TODO: deterministic forward pass?\n                img = img_orig * mask + (1.0 - mask) * img\n\n            outs = self.p_sample_ddim(\n                img,\n                cond,\n                ts,\n                index=index,\n                use_original_steps=ddim_use_original_steps,\n                quantize_denoised=quantize_denoised,\n                temperature=temperature,\n                noise_dropout=noise_dropout,\n                score_corrector=score_corrector,\n                corrector_kwargs=corrector_kwargs,\n                unconditional_guidance_scale=unconditional_guidance_scale,\n                unconditional_conditioning=unconditional_conditioning,\n            )\n            img, pred_x0 = outs\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(pred_x0, i)\n\n            if index % log_every_t == 0 or index == total_steps - 1:\n                intermediates['x_inter'].append(img)\n                intermediates['pred_x0'].append(pred_x0)\n\n        return img, intermediates\n\n    # This routine gets called from ddim_sampling() and decode()\n    @torch.no_grad()\n    def p_sample_ddim(\n        self,\n        x,\n        c,\n        t,\n        index,\n        repeat_noise=False,\n        use_original_steps=False,\n        quantize_denoised=False,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n    ):\n        b, *_, device = *x.shape, x.device\n\n        if (\n            unconditional_conditioning is None\n            or unconditional_guidance_scale == 1.0\n        ):\n            e_t = self.model.apply_model(x, t, 
c)\n        else:\n            x_in = torch.cat([x] * 2)\n            t_in = torch.cat([t] * 2)\n            c_in = torch.cat([unconditional_conditioning, c])\n            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)\n            e_t = e_t_uncond + unconditional_guidance_scale * (\n                e_t - e_t_uncond\n            )\n\n        if score_corrector is not None:\n            assert self.model.parameterization == 'eps'\n            e_t = score_corrector.modify_score(\n                self.model, e_t, x, t, c, **corrector_kwargs\n            )\n\n        alphas = (\n            self.model.alphas_cumprod\n            if use_original_steps\n            else self.ddim_alphas\n        )\n        alphas_prev = (\n            self.model.alphas_cumprod_prev\n            if use_original_steps\n            else self.ddim_alphas_prev\n        )\n        sqrt_one_minus_alphas = (\n            self.model.sqrt_one_minus_alphas_cumprod\n            if use_original_steps\n            else self.ddim_sqrt_one_minus_alphas\n        )\n        sigmas = (\n            self.model.ddim_sigmas_for_original_num_steps\n            if use_original_steps\n            else self.ddim_sigmas\n        )\n        # select parameters corresponding to the currently considered timestep\n        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n        sqrt_one_minus_at = torch.full(\n            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device\n        )\n\n        # current prediction for x_0\n        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n        if quantize_denoised:\n            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n        # direction pointing to x_t\n        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t\n        noise = (\n            sigma_t * 
noise_like(x.shape, device, repeat_noise) * temperature\n        )\n        if noise_dropout > 0.0:\n            noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n        return x_prev, pred_x0\n\n    @torch.no_grad()\n    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):\n        # fast, but does not allow for exact reconstruction\n        # t serves as an index to gather the correct alphas\n        if use_original_steps:\n            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod\n            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod\n        else:\n            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)\n            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas\n\n        if noise is None:\n            noise = torch.randn_like(x0)\n        return (\n            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0\n            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)\n            * noise\n        )\n\n    @torch.no_grad()\n    def decode(\n            self,\n            x_latent,\n            cond,\n            t_start,\n            img_callback=None,\n            unconditional_guidance_scale=1.0,\n            unconditional_conditioning=None,\n            use_original_steps=False,\n            init_latent       = None,\n            mask              = None,\n    ):\n\n        timesteps = (\n            np.arange(self.ddpm_num_timesteps)\n            if use_original_steps\n            else self.ddim_timesteps\n        )\n        timesteps = timesteps[:t_start]\n\n        time_range = np.flip(timesteps)\n        total_steps = timesteps.shape[0]\n        print(f'Running DDIM Sampling with {total_steps} timesteps')\n\n        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)\n        x_dec = x_latent\n        x0    = init_latent\n\n        for i, step in 
enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full(\n                (x_latent.shape[0],),\n                step,\n                device=x_latent.device,\n                dtype=torch.long,\n            )\n\n            if mask is not None:\n                assert x0 is not None\n                xdec_orig = self.model.q_sample(\n                    x0, ts\n                )  # TODO: deterministic forward pass?\n                x_dec = xdec_orig * mask + (1.0 - mask) * x_dec\n\n            x_dec, _ = self.p_sample_ddim(\n                x_dec,\n                cond,\n                ts,\n                index=index,\n                use_original_steps=use_original_steps,\n                unconditional_guidance_scale=unconditional_guidance_scale,\n                unconditional_conditioning=unconditional_conditioning,\n            )\n\n            if img_callback:\n                img_callback(x_dec, i)\n\n        return x_dec\n"
  },
  {
    "path": "src/stablediffusion/ldm/models/diffusion/ddpm.py",
    "content": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\nhttps://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py\nhttps://github.com/CompVis/taming-transformers\n-- merci\n\"\"\"\n\nimport torch\n\nimport torch.nn as nn\nimport os\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom einops import rearrange, repeat\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom tqdm import tqdm\nfrom torchvision.utils import make_grid\nfrom pytorch_lightning.utilities.distributed import rank_zero_only\nimport urllib\n\nfrom src.stablediffusion.ldm.util import (\n    log_txt_as_img,\n    exists,\n    default,\n    ismap,\n    isimage,\n    mean_flat,\n    count_params,\n    instantiate_from_config,\n)\nfrom src.stablediffusion.ldm.modules.ema import LitEma\nfrom src.stablediffusion.ldm.modules.distributions.distributions import (\n    normal_kl,\n    DiagonalGaussianDistribution,\n)\nfrom src.stablediffusion.ldm.models.autoencoder import (\n    VQModelInterface,\n    IdentityFirstStage,\n    AutoencoderKL,\n)\nfrom src.stablediffusion.ldm.modules.diffusionmodules.util import (\n    make_beta_schedule,\n    extract_into_tensor,\n    noise_like,\n)\nfrom src.stablediffusion.ldm.models.diffusion.ddim import DDIMSampler\n\n\n__conditioning_keys__ = {\n    'concat': 'c_concat',\n    'crossattn': 'c_crossattn',\n    'adm': 'y',\n}\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef uniform_on_device(r1, r2, shape, device):\n    return (r1 - r2) * torch.rand(*shape, device=device) + r2\n\n\nclass DDPM(pl.LightningModule):\n    # classic DDPM with Gaussian 
diffusion, in image space\n    def __init__(\n        self,\n        unet_config,\n        timesteps=1000,\n        beta_schedule='linear',\n        loss_type='l2',\n        ckpt_path=None,\n        ignore_keys=[],\n        load_only_unet=False,\n        monitor='val/loss',\n        use_ema=True,\n        first_stage_key='image',\n        image_size=256,\n        channels=3,\n        log_every_t=100,\n        clip_denoised=True,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n        given_betas=None,\n        original_elbo_weight=0.0,\n        embedding_reg_weight=0.0,\n        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta\n        l_simple_weight=1.0,\n        conditioning_key=None,\n        parameterization='eps',  # all assuming fixed variance schedules\n        scheduler_config=None,\n        use_positional_encodings=False,\n        learn_logvar=False,\n        logvar_init=0.0,\n    ):\n        super().__init__()\n        assert parameterization in [\n            'eps',\n            'x0',\n        ], 'currently only supporting \"eps\" and \"x0\"'\n        self.parameterization = parameterization\n        print(\n            f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'\n        )\n        self.cond_stage_model = None\n        self.clip_denoised = clip_denoised\n        self.log_every_t = log_every_t\n        self.first_stage_key = first_stage_key\n        self.image_size = image_size  # try conv?\n        self.channels = channels\n        self.use_positional_encodings = use_positional_encodings\n        self.model = DiffusionWrapper(unet_config, conditioning_key)\n        count_params(self.model, verbose=True)\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self.model)\n            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')\n\n        self.use_scheduler = scheduler_config is 
not None\n        if self.use_scheduler:\n            self.scheduler_config = scheduler_config\n\n        self.v_posterior = v_posterior\n        self.original_elbo_weight = original_elbo_weight\n        self.l_simple_weight = l_simple_weight\n        self.embedding_reg_weight = embedding_reg_weight\n\n        if monitor is not None:\n            self.monitor = monitor\n        if ckpt_path is not None:\n            self.init_from_ckpt(\n                ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet\n            )\n\n        self.register_schedule(\n            given_betas=given_betas,\n            beta_schedule=beta_schedule,\n            timesteps=timesteps,\n            linear_start=linear_start,\n            linear_end=linear_end,\n            cosine_s=cosine_s,\n        )\n\n        self.loss_type = loss_type\n\n        self.learn_logvar = learn_logvar\n        self.logvar = torch.full(\n            fill_value=logvar_init, size=(self.num_timesteps,)\n        )\n        if self.learn_logvar:\n            self.logvar = nn.Parameter(self.logvar, requires_grad=True)\n\n    def register_schedule(\n        self,\n        given_betas=None,\n        beta_schedule='linear',\n        timesteps=1000,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n    ):\n        if exists(given_betas):\n            betas = given_betas\n        else:\n            betas = make_beta_schedule(\n                beta_schedule,\n                timesteps,\n                linear_start=linear_start,\n                linear_end=linear_end,\n                cosine_s=cosine_s,\n            )\n        alphas = 1.0 - betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])\n\n        (timesteps,) = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert (\n            alphas_cumprod.shape[0] 
== self.num_timesteps\n        ), 'alphas have to be defined for each timestep'\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer('betas', to_torch(betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer(\n            'alphas_cumprod_prev', to_torch(alphas_cumprod_prev)\n        )\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\n            'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))\n        )\n        self.register_buffer(\n            'sqrt_one_minus_alphas_cumprod',\n            to_torch(np.sqrt(1.0 - alphas_cumprod)),\n        )\n        self.register_buffer(\n            'log_one_minus_alphas_cumprod',\n            to_torch(np.log(1.0 - alphas_cumprod)),\n        )\n        self.register_buffer(\n            'sqrt_recip_alphas_cumprod',\n            to_torch(np.sqrt(1.0 / alphas_cumprod)),\n        )\n        self.register_buffer(\n            'sqrt_recipm1_alphas_cumprod',\n            to_torch(np.sqrt(1.0 / alphas_cumprod - 1)),\n        )\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = (1 - self.v_posterior) * betas * (\n            1.0 - alphas_cumprod_prev\n        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas\n        # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t)\n        self.register_buffer(\n            'posterior_variance', to_torch(posterior_variance)\n        )\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.register_buffer(\n            'posterior_log_variance_clipped',\n            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),\n        )\n        self.register_buffer(\n            'posterior_mean_coef1',\n            to_torch(\n                betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)\n            ),\n        )\n        self.register_buffer(\n            'posterior_mean_coef2',\n            to_torch(\n                (1.0 - alphas_cumprod_prev)\n                * np.sqrt(alphas)\n                / (1.0 - alphas_cumprod)\n            ),\n        )\n\n        if self.parameterization == 'eps':\n            lvlb_weights = self.betas**2 / (\n                2\n                * self.posterior_variance\n                * to_torch(alphas)\n                * (1 - self.alphas_cumprod)\n            )\n        elif self.parameterization == 'x0':\n            lvlb_weights = (\n                0.5\n                * np.sqrt(torch.Tensor(alphas_cumprod))\n                / (2.0 * 1 - torch.Tensor(alphas_cumprod))\n            )\n        else:\n            raise NotImplementedError('mu not supported')\n        # TODO how to choose this term\n        lvlb_weights[0] = lvlb_weights[1]\n        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)\n        assert not torch.isnan(self.lvlb_weights).all()\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f'{context}: Switched to EMA weights')\n        try:\n            yield None\n        finally:\n            if 
self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                if context is not None:\n                    print(f'{context}: Restored training weights')\n\n    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):\n        sd = torch.load(path, map_location='cpu')\n        if 'state_dict' in list(sd.keys()):\n            sd = sd['state_dict']\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print('Deleting key {} from state_dict.'.format(k))\n                    del sd[k]\n        missing, unexpected = (\n            self.load_state_dict(sd, strict=False)\n            if not only_model\n            else self.model.load_state_dict(sd, strict=False)\n        )\n        print(\n            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'\n        )\n        if len(missing) > 0:\n            print(f'Missing Keys: {missing}')\n        if len(unexpected) > 0:\n            print(f'Unexpected Keys: {unexpected}')\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). 
Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = (\n            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)\n            * x_start\n        )\n        variance = extract_into_tensor(\n            1.0 - self.alphas_cumprod, t, x_start.shape\n        )\n        log_variance = extract_into_tensor(\n            self.log_one_minus_alphas_cumprod, t, x_start.shape\n        )\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)\n            * x_t\n            - extract_into_tensor(\n                self.sqrt_recipm1_alphas_cumprod, t, x_t.shape\n            )\n            * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape)\n            * x_start\n            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape)\n            * x_t\n        )\n        posterior_variance = extract_into_tensor(\n            self.posterior_variance, t, x_t.shape\n        )\n        posterior_log_variance_clipped = extract_into_tensor(\n            self.posterior_log_variance_clipped, t, x_t.shape\n        )\n        return (\n            posterior_mean,\n            posterior_variance,\n            posterior_log_variance_clipped,\n        )\n\n    def p_mean_variance(self, x, t, clip_denoised: bool):\n        model_out = self.model(x, t)\n        if self.parameterization == 'eps':\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == 'x0':\n            x_recon = model_out\n        if clip_denoised:\n            x_recon.clamp_(-1.0, 1.0)\n\n        (\n            model_mean,\n            posterior_variance,\n            posterior_log_variance,\n        ) = 
self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(\n            x=x, t=t, clip_denoised=clip_denoised\n        )\n        noise = noise_like(x.shape, device, repeat_noise)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(\n            b, *((1,) * (len(x.shape) - 1))\n        )\n        return (\n            model_mean\n            + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n        )\n\n    @torch.no_grad()\n    def p_sample_loop(self, shape, return_intermediates=False):\n        device = self.betas.device\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n        intermediates = [img]\n        for i in tqdm(\n            reversed(range(0, self.num_timesteps)),\n            desc='Sampling t',\n            total=self.num_timesteps,\n            dynamic_ncols=True,\n        ):\n            img = self.p_sample(\n                img,\n                torch.full((b,), i, device=device, dtype=torch.long),\n                clip_denoised=self.clip_denoised,\n            )\n            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:\n                intermediates.append(img)\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, batch_size=16, return_intermediates=False):\n        image_size = self.image_size\n        channels = self.channels\n        return self.p_sample_loop(\n            (batch_size, channels, image_size, image_size),\n            return_intermediates=return_intermediates,\n        )\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return (\n  
          extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)\n            * x_start\n            + extract_into_tensor(\n                self.sqrt_one_minus_alphas_cumprod, t, x_start.shape\n            )\n            * noise\n        )\n\n    def get_loss(self, pred, target, mean=True):\n        if self.loss_type == 'l1':\n            loss = (target - pred).abs()\n            if mean:\n                loss = loss.mean()\n        elif self.loss_type == 'l2':\n            if mean:\n                loss = torch.nn.functional.mse_loss(target, pred)\n            else:\n                loss = torch.nn.functional.mse_loss(\n                    target, pred, reduction='none'\n                )\n        else:\n            raise NotImplementedError(\"unknown loss type '{loss_type}'\")\n\n        return loss\n\n    def p_losses(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_out = self.model(x_noisy, t)\n\n        loss_dict = {}\n        if self.parameterization == 'eps':\n            target = noise\n        elif self.parameterization == 'x0':\n            target = x_start\n        else:\n            raise NotImplementedError(\n                f'Paramterization {self.parameterization} not yet supported'\n            )\n\n        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])\n\n        log_prefix = 'train' if self.training else 'val'\n\n        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})\n        loss_simple = loss.mean() * self.l_simple_weight\n\n        loss_vlb = (self.lvlb_weights[t] * loss).mean()\n        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})\n\n        loss = loss_simple + self.original_elbo_weight * loss_vlb\n\n        loss_dict.update({f'{log_prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def forward(self, x, *args, **kwargs):\n        # b, c, h, w, 
device, img_size, = *x.shape, x.device, self.image_size\n        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'\n        t = torch.randint(\n            0, self.num_timesteps, (x.shape[0],), device=self.device\n        ).long()\n        return self.p_losses(x, t, *args, **kwargs)\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = rearrange(x, 'b h w c -> b c h w')\n        x = x.to(memory_format=torch.contiguous_format).float()\n        return x\n\n    def shared_step(self, batch):\n        x = self.get_input(batch, self.first_stage_key)\n        loss, loss_dict = self(x)\n        return loss, loss_dict\n\n    def training_step(self, batch, batch_idx):\n        loss, loss_dict = self.shared_step(batch)\n\n        self.log_dict(\n            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True\n        )\n\n        self.log(\n            'global_step',\n            self.global_step,\n            prog_bar=True,\n            logger=True,\n            on_step=True,\n            on_epoch=False,\n        )\n\n        if self.use_scheduler:\n            lr = self.optimizers().param_groups[0]['lr']\n            self.log(\n                'lr_abs',\n                lr,\n                prog_bar=True,\n                logger=True,\n                on_step=True,\n                on_epoch=False,\n            )\n\n        return loss\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        _, loss_dict_no_ema = self.shared_step(batch)\n        with self.ema_scope():\n            _, loss_dict_ema = self.shared_step(batch)\n            loss_dict_ema = {\n                key + '_ema': loss_dict_ema[key] for key in loss_dict_ema\n            }\n        self.log_dict(\n            loss_dict_no_ema,\n            prog_bar=False,\n            logger=True,\n            on_step=False,\n            on_epoch=True,\n        )\n 
       self.log_dict(\n            loss_dict_ema,\n            prog_bar=False,\n            logger=True,\n            on_step=False,\n            on_epoch=True,\n        )\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self.model)\n\n    def _get_rows_from_list(self, samples):\n        n_imgs_per_row = len(samples)\n        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    @torch.no_grad()\n    def log_images(\n        self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs\n    ):\n        log = dict()\n        x = self.get_input(batch, self.first_stage_key)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        x = x.to(self.device)[:N]\n        log['inputs'] = x\n\n        # get diffusion row\n        diffusion_row = list()\n        x_start = x[:n_row]\n\n        for t in range(self.num_timesteps):\n            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                t = t.to(self.device).long()\n                noise = torch.randn_like(x_start)\n                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n                diffusion_row.append(x_noisy)\n\n        log['diffusion_row'] = self._get_rows_from_list(diffusion_row)\n\n        if sample:\n            # get denoise row\n            with self.ema_scope('Plotting'):\n                samples, denoise_row = self.sample(\n                    batch_size=N, return_intermediates=True\n                )\n\n            log['samples'] = samples\n            log['denoise_row'] = self._get_rows_from_list(denoise_row)\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n               
 return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        if self.learn_logvar:\n            params = params + [self.logvar]\n        opt = torch.optim.AdamW(params, lr=lr)\n        return opt\n\n\nclass LatentDiffusion(DDPM):\n    \"\"\"main class\"\"\"\n\n    def __init__(\n        self,\n        first_stage_config,\n        cond_stage_config,\n        personalization_config,\n        num_timesteps_cond=None,\n        cond_stage_key='image',\n        cond_stage_trainable=False,\n        concat_mode=True,\n        cond_stage_forward=None,\n        conditioning_key=None,\n        scale_factor=1.0,\n        scale_by_std=False,\n        *args,\n        **kwargs,\n    ):\n\n        self.num_timesteps_cond = default(num_timesteps_cond, 1)\n        self.scale_by_std = scale_by_std\n        assert self.num_timesteps_cond <= kwargs['timesteps']\n        # for backwards compatibility after implementation of DiffusionWrapper\n        if conditioning_key is None:\n            conditioning_key = 'concat' if concat_mode else 'crossattn'\n        if cond_stage_config == '__is_unconditional__':\n            conditioning_key = None\n        ckpt_path = kwargs.pop('ckpt_path', None)\n        ignore_keys = kwargs.pop('ignore_keys', [])\n        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)\n        self.concat_mode = concat_mode\n        self.cond_stage_trainable = cond_stage_trainable\n        self.cond_stage_key = cond_stage_key\n\n        try:\n            self.num_downs = (\n                len(first_stage_config.params.ddconfig.ch_mult) - 1\n            )\n        except:\n            self.num_downs = 0\n        if not scale_by_std:\n            self.scale_factor = scale_factor\n        else:\n            self.register_buffer('scale_factor', torch.tensor(scale_factor))\n   
     self.instantiate_first_stage(first_stage_config)\n        self.instantiate_cond_stage(cond_stage_config)\n\n        self.cond_stage_forward = cond_stage_forward\n        self.clip_denoised = False\n        self.bbox_tokenizer = None\n\n        self.restarted_from_ckpt = False\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys)\n            self.restarted_from_ckpt = True\n\n        self.cond_stage_model.train = disabled_train\n        for param in self.cond_stage_model.parameters():\n            param.requires_grad = False\n\n        self.model.eval()\n        self.model.train = disabled_train\n        for param in self.model.parameters():\n            param.requires_grad = False\n\n        self.embedding_manager = self.instantiate_embedding_manager(\n            personalization_config, self.cond_stage_model\n        )\n\n        self.emb_ckpt_counter = 0\n\n        # if self.embedding_manager.is_clip:\n        #     self.cond_stage_model.update_embedding_func(self.embedding_manager)\n\n        for param in self.embedding_manager.embedding_parameters():\n            param.requires_grad = True\n\n    def make_cond_schedule(\n        self,\n    ):\n        self.cond_ids = torch.full(\n            size=(self.num_timesteps,),\n            fill_value=self.num_timesteps - 1,\n            dtype=torch.long,\n        )\n        ids = torch.round(\n            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)\n        ).long()\n        self.cond_ids[: self.num_timesteps_cond] = ids\n\n    @rank_zero_only\n    @torch.no_grad()\n    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):\n        # only for very first batch\n        if (\n            self.scale_by_std\n            and self.current_epoch == 0\n            and self.global_step == 0\n            and batch_idx == 0\n            and not self.restarted_from_ckpt\n        ):\n            assert (\n                self.scale_factor == 1.0\n        
    ), 'rather not use custom rescaling and std-rescaling simultaneously'\n            # set rescale weight to 1./std of encodings\n            print('### USING STD-RESCALING ###')\n            x = super().get_input(batch, self.first_stage_key)\n            x = x.to(self.device)\n            encoder_posterior = self.encode_first_stage(x)\n            z = self.get_first_stage_encoding(encoder_posterior).detach()\n            del self.scale_factor\n            self.register_buffer('scale_factor', 1.0 / z.flatten().std())\n            print(f'setting self.scale_factor to {self.scale_factor}')\n            print('### USING STD-RESCALING ###')\n\n    def register_schedule(\n        self,\n        given_betas=None,\n        beta_schedule='linear',\n        timesteps=1000,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n    ):\n        super().register_schedule(\n            given_betas,\n            beta_schedule,\n            timesteps,\n            linear_start,\n            linear_end,\n            cosine_s,\n        )\n\n        self.shorten_cond_schedule = self.num_timesteps_cond > 1\n        if self.shorten_cond_schedule:\n            self.make_cond_schedule()\n\n    def instantiate_first_stage(self, config):\n        model = instantiate_from_config(config)\n        self.first_stage_model = model.eval()\n        self.first_stage_model.train = disabled_train\n        for param in self.first_stage_model.parameters():\n            param.requires_grad = False\n\n    def instantiate_cond_stage(self, config):\n        if not self.cond_stage_trainable:\n            if config == '__is_first_stage__':\n                print('Using first stage also as cond stage.')\n                self.cond_stage_model = self.first_stage_model\n            elif config == '__is_unconditional__':\n                print(\n                    f'Training {self.__class__.__name__} as an unconditional model.'\n                )\n                self.cond_stage_model 
= None\n                # self.be_unconditional = True\n            else:\n                model = instantiate_from_config(config)\n                self.cond_stage_model = model.eval()\n                self.cond_stage_model.train = disabled_train\n                for param in self.cond_stage_model.parameters():\n                    param.requires_grad = False\n        else:\n            assert config != '__is_first_stage__'\n            assert config != '__is_unconditional__'\n            try:\n                model = instantiate_from_config(config)\n            except urllib.error.URLError:\n                raise SystemExit(\n                    \"* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.\"\n                )\n            self.cond_stage_model = model\n\n    def instantiate_embedding_manager(self, config, embedder):\n        model = instantiate_from_config(config, embedder=embedder)\n\n        if config.params.get(\n            'embedding_manager_ckpt', None\n        ):   # do not load if missing OR empty string\n            model.load(config.params.embedding_manager_ckpt)\n\n        return model\n\n    def _get_denoise_row_from_list(\n        self, samples, desc='', force_no_decoder_quantization=False\n    ):\n        denoise_row = []\n        for zd in tqdm(samples, desc=desc):\n            denoise_row.append(\n                self.decode_first_stage(\n                    zd.to(self.device),\n                    force_not_quantize=force_no_decoder_quantization,\n                )\n            )\n        n_imgs_per_row = len(denoise_row)\n        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W\n        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    def get_first_stage_encoding(self, 
encoder_posterior):\n        if isinstance(encoder_posterior, DiagonalGaussianDistribution):\n            z = encoder_posterior.sample()\n        elif isinstance(encoder_posterior, torch.Tensor):\n            z = encoder_posterior\n        else:\n            raise NotImplementedError(\n                f\"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented\"\n            )\n        return self.scale_factor * z\n\n    def get_learned_conditioning(self, c):\n        if self.cond_stage_forward is None:\n            if hasattr(self.cond_stage_model, 'encode') and callable(\n                self.cond_stage_model.encode\n            ):\n                c = self.cond_stage_model.encode(\n                    c, embedding_manager=self.embedding_manager\n                )\n                if isinstance(c, DiagonalGaussianDistribution):\n                    c = c.mode()\n            else:\n                c = self.cond_stage_model(c)\n        else:\n            assert hasattr(self.cond_stage_model, self.cond_stage_forward)\n            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)\n        return c\n\n    def meshgrid(self, h, w):\n        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)\n        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)\n\n        arr = torch.cat([y, x], dim=-1)\n        return arr\n\n    def delta_border(self, h, w):\n        \"\"\"\n        :param h: height\n        :param w: width\n        :return: normalized distance to image border,\n         wtith min distance = 0 at border and max dist = 0.5 at image center\n        \"\"\"\n        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)\n        arr = self.meshgrid(h, w) / lower_right_corner\n        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]\n        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]\n        edge_dist = torch.min(\n            torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1\n        
)[0]\n        return edge_dist\n\n    def get_weighting(self, h, w, Ly, Lx, device):\n        weighting = self.delta_border(h, w)\n        weighting = torch.clip(\n            weighting,\n            self.split_input_params['clip_min_weight'],\n            self.split_input_params['clip_max_weight'],\n        )\n        weighting = (\n            weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)\n        )\n\n        if self.split_input_params['tie_braker']:\n            L_weighting = self.delta_border(Ly, Lx)\n            L_weighting = torch.clip(\n                L_weighting,\n                self.split_input_params['clip_min_tie_weight'],\n                self.split_input_params['clip_max_tie_weight'],\n            )\n\n            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)\n            weighting = weighting * L_weighting\n        return weighting\n\n    def get_fold_unfold(\n        self, x, kernel_size, stride, uf=1, df=1\n    ):  # todo load once not every time, shorten code\n        \"\"\"\n        :param x: img of size (bs, c, h, w)\n        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])\n        \"\"\"\n        bs, nc, h, w = x.shape\n\n        # number of crops in image\n        Ly = (h - kernel_size[0]) // stride[0] + 1\n        Lx = (w - kernel_size[1]) // stride[1] + 1\n\n        if uf == 1 and df == 1:\n            fold_params = dict(\n                kernel_size=kernel_size, dilation=1, padding=0, stride=stride\n            )\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)\n\n            weighting = self.get_weighting(\n                kernel_size[0], kernel_size[1], Ly, Lx, x.device\n            ).to(x.dtype)\n            normalization = fold(weighting).view(\n                1, 1, h, w\n            )  # normalizes the overlap\n            weighting = weighting.view(\n                (1, 1, kernel_size[0], kernel_size[1], 
Ly * Lx)\n            )\n\n        elif uf > 1 and df == 1:\n            fold_params = dict(\n                kernel_size=kernel_size, dilation=1, padding=0, stride=stride\n            )\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(\n                kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),\n                dilation=1,\n                padding=0,\n                stride=(stride[0] * uf, stride[1] * uf),\n            )\n            fold = torch.nn.Fold(\n                output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2\n            )\n\n            weighting = self.get_weighting(\n                kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device\n            ).to(x.dtype)\n            normalization = fold(weighting).view(\n                1, 1, h * uf, w * uf\n            )  # normalizes the overlap\n            weighting = weighting.view(\n                (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)\n            )\n\n        elif df > 1 and uf == 1:\n            fold_params = dict(\n                kernel_size=kernel_size, dilation=1, padding=0, stride=stride\n            )\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(\n                kernel_size=(kernel_size[0] // df, kernel_size[0] // df),\n                dilation=1,\n                padding=0,\n                stride=(stride[0] // df, stride[1] // df),\n            )\n            fold = torch.nn.Fold(\n                output_size=(x.shape[2] // df, x.shape[3] // df),\n                **fold_params2,\n            )\n\n            weighting = self.get_weighting(\n                kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device\n            ).to(x.dtype)\n            normalization = fold(weighting).view(\n                1, 1, h // df, w // df\n            )  # normalizes the overlap\n            weighting = weighting.view(\n                (1, 1, kernel_size[0] // df, 
kernel_size[1] // df, Ly * Lx)\n            )\n\n        else:\n            raise NotImplementedError\n\n        return fold, unfold, normalization, weighting\n\n    @torch.no_grad()\n    def get_input(\n        self,\n        batch,\n        k,\n        return_first_stage_outputs=False,\n        force_c_encode=False,\n        cond_key=None,\n        return_original_cond=False,\n        bs=None,\n    ):\n        x = super().get_input(batch, k)\n        if bs is not None:\n            x = x[:bs]\n        x = x.to(self.device)\n        encoder_posterior = self.encode_first_stage(x)\n        z = self.get_first_stage_encoding(encoder_posterior).detach()\n\n        if self.model.conditioning_key is not None:\n            if cond_key is None:\n                cond_key = self.cond_stage_key\n            if cond_key != self.first_stage_key:\n                if cond_key in ['caption', 'coordinates_bbox']:\n                    xc = batch[cond_key]\n                elif cond_key == 'class_label':\n                    xc = batch\n                else:\n                    xc = super().get_input(batch, cond_key).to(self.device)\n            else:\n                xc = x\n            if not self.cond_stage_trainable or force_c_encode:\n                if isinstance(xc, dict) or isinstance(xc, list):\n                    # import pudb; pudb.set_trace()\n                    c = self.get_learned_conditioning(xc)\n                else:\n                    c = self.get_learned_conditioning(xc.to(self.device))\n            else:\n                c = xc\n            if bs is not None:\n                c = c[:bs]\n\n            if self.use_positional_encodings:\n                pos_x, pos_y = self.compute_latent_shifts(batch)\n                ckey = __conditioning_keys__[self.model.conditioning_key]\n                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}\n\n        else:\n            c = None\n            xc = None\n            if self.use_positional_encodings:\n                
pos_x, pos_y = self.compute_latent_shifts(batch)\n                c = {'pos_x': pos_x, 'pos_y': pos_y}\n        out = [z, c]\n        if return_first_stage_outputs:\n            xrec = self.decode_first_stage(z)\n            out.extend([x, xrec])\n        if return_original_cond:\n            out.append(xc)\n        return out\n\n    @torch.no_grad()\n    def decode_first_stage(\n        self, z, predict_cids=False, force_not_quantize=False\n    ):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(\n                z, shape=None\n            )\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1.0 / self.scale_factor * z\n\n        if hasattr(self, 'split_input_params'):\n            if self.split_input_params['patch_distributed_vq']:\n                ks = self.split_input_params['ks']  # eg. (128, 128)\n                stride = self.split_input_params['stride']  # eg. (64, 64)\n                uf = self.split_input_params['vqf']\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print('reducing Kernel')\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print('reducing stride')\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(\n                    z, ks, stride, uf=uf\n                )\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view(\n                    (z.shape[0], -1, ks[0], ks[1], z.shape[-1])\n                )  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. 
apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [\n                        self.first_stage_model.decode(\n                            z[:, :, :, :, i],\n                            force_not_quantize=predict_cids\n                            or force_not_quantize,\n                        )\n                        for i in range(z.shape[-1])\n                    ]\n                else:\n\n                    output_list = [\n                        self.first_stage_model.decode(z[:, :, :, :, i])\n                        for i in range(z.shape[-1])\n                    ]\n\n                o = torch.stack(\n                    output_list, axis=-1\n                )  # # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. reshape to img shape\n                o = o.view(\n                    (o.shape[0], -1, o.shape[-1])\n                )  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(\n                        z,\n                        force_not_quantize=predict_cids or force_not_quantize,\n                    )\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(\n                    z, force_not_quantize=predict_cids or force_not_quantize\n                )\n            else:\n                return self.first_stage_model.decode(z)\n\n    # same as above but without decorator\n    def differentiable_decode_first_stage(\n        self, z, 
predict_cids=False, force_not_quantize=False\n    ):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(\n                z, shape=None\n            )\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1.0 / self.scale_factor * z\n\n        if hasattr(self, 'split_input_params'):\n            if self.split_input_params['patch_distributed_vq']:\n                ks = self.split_input_params['ks']  # eg. (128, 128)\n                stride = self.split_input_params['stride']  # eg. (64, 64)\n                uf = self.split_input_params['vqf']\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print('reducing Kernel')\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print('reducing stride')\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(\n                    z, ks, stride, uf=uf\n                )\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view(\n                    (z.shape[0], -1, ks[0], ks[1], z.shape[-1])\n                )  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. 
apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [\n                        self.first_stage_model.decode(\n                            z[:, :, :, :, i],\n                            force_not_quantize=predict_cids\n                            or force_not_quantize,\n                        )\n                        for i in range(z.shape[-1])\n                    ]\n                else:\n\n                    output_list = [\n                        self.first_stage_model.decode(z[:, :, :, :, i])\n                        for i in range(z.shape[-1])\n                    ]\n\n                o = torch.stack(\n                    output_list, axis=-1\n                )  # # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. reshape to img shape\n                o = o.view(\n                    (o.shape[0], -1, o.shape[-1])\n                )  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(\n                        z,\n                        force_not_quantize=predict_cids or force_not_quantize,\n                    )\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(\n                    z, force_not_quantize=predict_cids or force_not_quantize\n                )\n            else:\n                return self.first_stage_model.decode(z)\n\n    @torch.no_grad()\n    def encode_first_stage(self, x):\n        if hasattr(self, 'split_input_params'):\n 
           if self.split_input_params['patch_distributed_vq']:\n                ks = self.split_input_params['ks']  # eg. (128, 128)\n                stride = self.split_input_params['stride']  # eg. (64, 64)\n                df = self.split_input_params['vqf']\n                self.split_input_params['original_image_size'] = x.shape[-2:]\n                bs, nc, h, w = x.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print('reducing Kernel')\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print('reducing stride')\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(\n                    x, ks, stride, df=df\n                )\n                z = unfold(x)  # (bn, nc * prod(**ks), L)\n                # Reshape to img shape\n                z = z.view(\n                    (z.shape[0], -1, ks[0], ks[1], z.shape[-1])\n                )  # (bn, nc, ks[0], ks[1], L )\n\n                output_list = [\n                    self.first_stage_model.encode(z[:, :, :, :, i])\n                    for i in range(z.shape[-1])\n                ]\n\n                o = torch.stack(output_list, axis=-1)\n                o = o * weighting\n\n                # Reverse reshape to img shape\n                o = o.view(\n                    (o.shape[0], -1, o.shape[-1])\n                )  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization\n                return decoded\n\n            else:\n                return self.first_stage_model.encode(x)\n        else:\n            return self.first_stage_model.encode(x)\n\n    def shared_step(self, batch, **kwargs):\n        x, c = self.get_input(batch, self.first_stage_key)\n        loss = self(x, c)\n        return loss\n\n    def 
forward(self, x, c, *args, **kwargs):\n        t = torch.randint(\n            0, self.num_timesteps, (x.shape[0],), device=self.device\n        ).long()\n        if self.model.conditioning_key is not None:\n            assert c is not None\n            if self.cond_stage_trainable:\n                c = self.get_learned_conditioning(c)\n            if self.shorten_cond_schedule:  # TODO: drop this option\n                tc = self.cond_ids[t].to(self.device)\n                c = self.q_sample(\n                    x_start=c, t=tc, noise=torch.randn_like(c.float())\n                )\n\n        return self.p_losses(x, c, t, *args, **kwargs)\n\n    def _rescale_annotations(\n        self, bboxes, crop_coordinates\n    ):  # TODO: move to dataset\n        def rescale_bbox(bbox):\n            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])\n            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])\n            w = min(bbox[2] / crop_coordinates[2], 1 - x0)\n            h = min(bbox[3] / crop_coordinates[3], 1 - y0)\n            return x0, y0, w, h\n\n        return [rescale_bbox(b) for b in bboxes]\n\n    def apply_model(self, x_noisy, t, cond, return_ids=False):\n\n        if isinstance(cond, dict):\n            # hybrid case, cond is exptected to be a dict\n            pass\n        else:\n            if not isinstance(cond, list):\n                cond = [cond]\n            key = (\n                'c_concat'\n                if self.model.conditioning_key == 'concat'\n                else 'c_crossattn'\n            )\n            cond = {key: cond}\n\n        if hasattr(self, 'split_input_params'):\n            assert (\n                len(cond) == 1\n            )  # todo can only deal with one conditioning atm\n            assert not return_ids\n            ks = self.split_input_params['ks']  # eg. (128, 128)\n            stride = self.split_input_params['stride']  # eg. 
(64, 64)\n\n            h, w = x_noisy.shape[-2:]\n\n            fold, unfold, normalization, weighting = self.get_fold_unfold(\n                x_noisy, ks, stride\n            )\n\n            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)\n            # Reshape to img shape\n            z = z.view(\n                (z.shape[0], -1, ks[0], ks[1], z.shape[-1])\n            )  # (bn, nc, ks[0], ks[1], L )\n            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]\n\n            if (\n                self.cond_stage_key\n                in ['image', 'LR_image', 'segmentation', 'bbox_img']\n                and self.model.conditioning_key\n            ):  # todo check for completeness\n                c_key = next(iter(cond.keys()))  # get key\n                c = next(iter(cond.values()))  # get value\n                assert (\n                    len(c) == 1\n                )  # todo extend to list with more than one elem\n                c = c[0]  # get element\n\n                c = unfold(c)\n                c = c.view(\n                    (c.shape[0], -1, ks[0], ks[1], c.shape[-1])\n                )  # (bn, nc, ks[0], ks[1], L )\n\n                cond_list = [\n                    {c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])\n                ]\n\n            elif self.cond_stage_key == 'coordinates_bbox':\n                assert (\n                    'original_image_size' in self.split_input_params\n                ), 'BoudingBoxRescaling is missing original_image_size'\n\n                # assuming padding of unfold is always 0 and its dilation is always 1\n                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)\n                full_img_h, full_img_w = self.split_input_params[\n                    'original_image_size'\n                ]\n                # as we are operating on latents, we need the factor from the original image size to the\n                # spatial latent size to properly rescale the crops for 
regenerating the bbox annotations\n                num_downs = self.first_stage_model.encoder.num_resolutions - 1\n                rescale_latent = 2 ** (num_downs)\n\n                # get top left postions of patches as conforming for the bbbox tokenizer, therefore we\n                # need to rescale the tl patch coordinates to be in between (0,1)\n                tl_patch_coordinates = [\n                    (\n                        rescale_latent\n                        * stride[0]\n                        * (patch_nr % n_patches_per_row)\n                        / full_img_w,\n                        rescale_latent\n                        * stride[1]\n                        * (patch_nr // n_patches_per_row)\n                        / full_img_h,\n                    )\n                    for patch_nr in range(z.shape[-1])\n                ]\n\n                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)\n                patch_limits = [\n                    (\n                        x_tl,\n                        y_tl,\n                        rescale_latent * ks[0] / full_img_w,\n                        rescale_latent * ks[1] / full_img_h,\n                    )\n                    for x_tl, y_tl in tl_patch_coordinates\n                ]\n                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]\n\n                # tokenize crop coordinates for the bounding boxes of the respective patches\n                patch_limits_tknzd = [\n                    torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[\n                        None\n                    ].to(self.device)\n                    for bbox in patch_limits\n                ]  # list of length l with tensors of shape (1, 2)\n                print(patch_limits_tknzd[0].shape)\n                # cut tknzd crop position from conditioning\n                assert 
isinstance(\n                    cond, dict\n                ), 'cond must be dict to be fed into model'\n                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)\n                print(cut_cond.shape)\n\n                adapted_cond = torch.stack(\n                    [\n                        torch.cat([cut_cond, p], dim=1)\n                        for p in patch_limits_tknzd\n                    ]\n                )\n                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')\n                print(adapted_cond.shape)\n                adapted_cond = self.get_learned_conditioning(adapted_cond)\n                print(adapted_cond.shape)\n                adapted_cond = rearrange(\n                    adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]\n                )\n                print(adapted_cond.shape)\n\n                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]\n\n            else:\n                cond_list = [\n                    cond for i in range(z.shape[-1])\n                ]  # Todo make this more efficient\n\n            # apply model by loop over crops\n            output_list = [\n                self.model(z_list[i], t, **cond_list[i])\n                for i in range(z.shape[-1])\n            ]\n            assert not isinstance(\n                output_list[0], tuple\n            )  # todo cant deal with multiple model outputs check this never happens\n\n            o = torch.stack(output_list, axis=-1)\n            o = o * weighting\n            # Reverse reshape to img shape\n            o = o.view(\n                (o.shape[0], -1, o.shape[-1])\n            )  # (bn, nc * ks[0] * ks[1], L)\n            # stitch crops together\n            x_recon = fold(o) / normalization\n\n        else:\n            x_recon = self.model(x_noisy, t, **cond)\n\n        if isinstance(x_recon, tuple) and not return_ids:\n            return x_recon[0]\n        else:\n            return x_recon\n\n    def 
_predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (\n            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)\n            * x_t\n            - pred_xstart\n        ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational lower-bound, measured in\n        bits-per-dim.\n        This term can't be optimized, as it only depends on the encoder.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = torch.tensor(\n            [self.num_timesteps - 1] * batch_size, device=x_start.device\n        )\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(\n            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0\n        )\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def p_losses(self, x_start, cond, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_output = self.apply_model(x_noisy, t, cond)\n\n        loss_dict = {}\n        prefix = 'train' if self.training else 'val'\n\n        if self.parameterization == 'x0':\n            target = x_start\n        elif self.parameterization == 'eps':\n            target = noise\n        else:\n            raise NotImplementedError()\n\n        loss_simple = self.get_loss(model_output, target, mean=False).mean(\n            [1, 2, 3]\n        )\n        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})\n\n        logvar_t = self.logvar[t].to(self.device)\n        loss = loss_simple / torch.exp(logvar_t) + logvar_t\n        # loss = loss_simple / torch.exp(self.logvar) + self.logvar\n        if self.learn_logvar:\n            
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})\n            loss_dict.update({'logvar': self.logvar.data.mean()})\n\n        loss = self.l_simple_weight * loss.mean()\n\n        loss_vlb = self.get_loss(model_output, target, mean=False).mean(\n            dim=(1, 2, 3)\n        )\n        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()\n        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})\n        loss += self.original_elbo_weight * loss_vlb\n        loss_dict.update({f'{prefix}/loss': loss})\n\n        if self.embedding_reg_weight > 0:\n            loss_embedding_reg = (\n                self.embedding_manager.embedding_to_coarse_loss().mean()\n            )\n\n            loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg})\n\n            loss += self.embedding_reg_weight * loss_embedding_reg\n            loss_dict.update({f'{prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def p_mean_variance(\n        self,\n        x,\n        c,\n        t,\n        clip_denoised: bool,\n        return_codebook_ids=False,\n        quantize_denoised=False,\n        return_x0=False,\n        score_corrector=None,\n        corrector_kwargs=None,\n    ):\n        t_in = t\n        model_out = self.apply_model(\n            x, t_in, c, return_ids=return_codebook_ids\n        )\n\n        if score_corrector is not None:\n            assert self.parameterization == 'eps'\n            model_out = score_corrector.modify_score(\n                self, model_out, x, t, c, **corrector_kwargs\n            )\n\n        if return_codebook_ids:\n            model_out, logits = model_out\n\n        if self.parameterization == 'eps':\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == 'x0':\n            x_recon = model_out\n        else:\n            raise NotImplementedError()\n\n        if clip_denoised:\n            x_recon.clamp_(-1.0, 1.0)\n        if quantize_denoised:\n           
 x_recon, _, [_, _, indices] = self.first_stage_model.quantize(\n                x_recon\n            )\n        (\n            model_mean,\n            posterior_variance,\n            posterior_log_variance,\n        ) = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        if return_codebook_ids:\n            return (\n                model_mean,\n                posterior_variance,\n                posterior_log_variance,\n                logits,\n            )\n        elif return_x0:\n            return (\n                model_mean,\n                posterior_variance,\n                posterior_log_variance,\n                x_recon,\n            )\n        else:\n            return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(\n        self,\n        x,\n        c,\n        t,\n        clip_denoised=False,\n        repeat_noise=False,\n        return_codebook_ids=False,\n        quantize_denoised=False,\n        return_x0=False,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n    ):\n        b, *_, device = *x.shape, x.device\n        outputs = self.p_mean_variance(\n            x=x,\n            c=c,\n            t=t,\n            clip_denoised=clip_denoised,\n            return_codebook_ids=return_codebook_ids,\n            quantize_denoised=quantize_denoised,\n            return_x0=return_x0,\n            score_corrector=score_corrector,\n            corrector_kwargs=corrector_kwargs,\n        )\n        if return_codebook_ids:\n            raise DeprecationWarning('Support dropped.')\n            model_mean, _, model_log_variance, logits = outputs\n        elif return_x0:\n            model_mean, _, model_log_variance, x0 = outputs\n        else:\n            model_mean, _, model_log_variance = outputs\n\n        noise = noise_like(x.shape, device, repeat_noise) * temperature\n        if noise_dropout > 0.0:\n            noise = 
torch.nn.functional.dropout(noise, p=noise_dropout)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(\n            b, *((1,) * (len(x.shape) - 1))\n        )\n\n        if return_codebook_ids:\n            return model_mean + nonzero_mask * (\n                0.5 * model_log_variance\n            ).exp() * noise, logits.argmax(dim=1)\n        if return_x0:\n            return (\n                model_mean\n                + nonzero_mask * (0.5 * model_log_variance).exp() * noise,\n                x0,\n            )\n        else:\n            return (\n                model_mean\n                + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n            )\n\n    @torch.no_grad()\n    def progressive_denoising(\n        self,\n        cond,\n        shape,\n        verbose=True,\n        callback=None,\n        quantize_denoised=False,\n        img_callback=None,\n        mask=None,\n        x0=None,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        batch_size=None,\n        x_T=None,\n        start_T=None,\n        log_every_t=None,\n    ):\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        timesteps = self.num_timesteps\n        if batch_size is not None:\n            b = batch_size if batch_size is not None else shape[0]\n            shape = [batch_size] + list(shape)\n        else:\n            b = batch_size = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=self.device)\n        else:\n            img = x_T\n        intermediates = []\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {\n                    key: cond[key][:batch_size]\n                    if not isinstance(cond[key], list)\n                    else list(map(lambda x: x[:batch_size], cond[key]))\n                    for key in cond\n                }\n            
else:\n                cond = (\n                    [c[:batch_size] for c in cond]\n                    if isinstance(cond, list)\n                    else cond[:batch_size]\n                )\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = (\n            tqdm(\n                reversed(range(0, timesteps)),\n                desc='Progressive Generation',\n                total=timesteps,\n            )\n            if verbose\n            else reversed(range(0, timesteps))\n        )\n        if type(temperature) == float:\n            temperature = [temperature] * timesteps\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=self.device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(\n                    x_start=cond, t=tc, noise=torch.randn_like(cond)\n                )\n\n            img, x0_partial = self.p_sample(\n                img,\n                cond,\n                ts,\n                clip_denoised=self.clip_denoised,\n                quantize_denoised=quantize_denoised,\n                return_x0=True,\n                temperature=temperature[i],\n                noise_dropout=noise_dropout,\n                score_corrector=score_corrector,\n                corrector_kwargs=corrector_kwargs,\n            )\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1.0 - mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(x0_partial)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(img, i)\n        return img, intermediates\n\n    @torch.no_grad()\n    def 
p_sample_loop(\n        self,\n        cond,\n        shape,\n        return_intermediates=False,\n        x_T=None,\n        verbose=True,\n        callback=None,\n        timesteps=None,\n        quantize_denoised=False,\n        mask=None,\n        x0=None,\n        img_callback=None,\n        start_T=None,\n        log_every_t=None,\n    ):\n\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        device = self.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        intermediates = [img]\n        if timesteps is None:\n            timesteps = self.num_timesteps\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = (\n            tqdm(\n                reversed(range(0, timesteps)),\n                desc='Sampling t',\n                total=timesteps,\n            )\n            if verbose\n            else reversed(range(0, timesteps))\n        )\n\n        if mask is not None:\n            assert x0 is not None\n            assert (\n                x0.shape[2:3] == mask.shape[2:3]\n            )  # spatial size has to match\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(\n                    x_start=cond, t=tc, noise=torch.randn_like(cond)\n                )\n\n            img = self.p_sample(\n                img,\n                cond,\n                ts,\n                clip_denoised=self.clip_denoised,\n                quantize_denoised=quantize_denoised,\n            )\n            if mask is not None:\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1.0 - mask) * img\n\n          
  if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(img)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(img, i)\n\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(\n        self,\n        cond,\n        batch_size=16,\n        return_intermediates=False,\n        x_T=None,\n        verbose=True,\n        timesteps=None,\n        quantize_denoised=False,\n        mask=None,\n        x0=None,\n        shape=None,\n        **kwargs,\n    ):\n        if shape is None:\n            shape = (\n                batch_size,\n                self.channels,\n                self.image_size,\n                self.image_size,\n            )\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {\n                    key: cond[key][:batch_size]\n                    if not isinstance(cond[key], list)\n                    else list(map(lambda x: x[:batch_size], cond[key]))\n                    for key in cond\n                }\n            else:\n                cond = (\n                    [c[:batch_size] for c in cond]\n                    if isinstance(cond, list)\n                    else cond[:batch_size]\n                )\n        return self.p_sample_loop(\n            cond,\n            shape,\n            return_intermediates=return_intermediates,\n            x_T=x_T,\n            verbose=verbose,\n            timesteps=timesteps,\n            quantize_denoised=quantize_denoised,\n            mask=mask,\n            x0=x0,\n        )\n\n    @torch.no_grad()\n    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):\n\n        if ddim:\n            ddim_sampler = DDIMSampler(self)\n            shape = (self.channels, self.image_size, self.image_size)\n            samples, intermediates = ddim_sampler.sample(\n                
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs\n            )\n\n        else:\n            samples, intermediates = self.sample(\n                cond=cond,\n                batch_size=batch_size,\n                return_intermediates=True,\n                **kwargs,\n            )\n\n        return samples, intermediates\n\n    @torch.no_grad()\n    def log_images(\n        self,\n        batch,\n        N=8,\n        n_row=4,\n        sample=True,\n        ddim_steps=200,\n        ddim_eta=1.0,\n        return_keys=None,\n        quantize_denoised=True,\n        inpaint=False,\n        plot_denoise_rows=False,\n        plot_progressive_rows=False,\n        plot_diffusion_rows=False,\n        **kwargs,\n    ):\n\n        use_ddim = ddim_steps is not None\n\n        log = dict()\n        z, c, x, xrec, xc = self.get_input(\n            batch,\n            self.first_stage_key,\n            return_first_stage_outputs=True,\n            force_c_encode=True,\n            return_original_cond=True,\n            bs=N,\n        )\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log['inputs'] = x\n        log['reconstruction'] = xrec\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, 'decode'):\n                xc = self.cond_stage_model.decode(c)\n                log['conditioning'] = xc\n            elif self.cond_stage_key in ['caption']:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch['caption'])\n                log['conditioning'] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img(\n                    (x.shape[2], x.shape[3]), batch['human_label']\n                )\n                log['conditioning'] = xc\n            elif isimage(xc):\n                log['conditioning'] = xc\n            if ismap(xc):\n                log['original_conditioning'] = self.to_rgb(xc)\n\n        if plot_diffusion_rows:\n 
           # get diffusion row\n            diffusion_row = list()\n            z_start = z[:n_row]\n            for t in range(self.num_timesteps):\n                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                    t = t.to(self.device).long()\n                    noise = torch.randn_like(z_start)\n                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)\n                    diffusion_row.append(self.decode_first_stage(z_noisy))\n\n            diffusion_row = torch.stack(\n                diffusion_row\n            )  # n_log_step, n_row, C, H, W\n            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')\n            diffusion_grid = rearrange(\n                diffusion_grid, 'b n c h w -> (b n) c h w'\n            )\n            diffusion_grid = make_grid(\n                diffusion_grid, nrow=diffusion_row.shape[0]\n            )\n            log['diffusion_row'] = diffusion_grid\n\n        if sample:\n            # get denoise row\n            with self.ema_scope('Plotting'):\n                samples, z_denoise_row = self.sample_log(\n                    cond=c,\n                    batch_size=N,\n                    ddim=use_ddim,\n                    ddim_steps=ddim_steps,\n                    eta=ddim_eta,\n                )\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log['samples'] = x_samples\n            if plot_denoise_rows:\n                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)\n                log['denoise_row'] = denoise_grid\n\n            uc = self.get_learned_conditioning(len(c) * [''])\n            sample_scaled, _ = self.sample_log(\n                cond=c,\n                batch_size=N,\n                ddim=use_ddim,\n                
ddim_steps=ddim_steps,\n                eta=ddim_eta,\n                unconditional_guidance_scale=5.0,\n                unconditional_conditioning=uc,\n            )\n            log['samples_scaled'] = self.decode_first_stage(sample_scaled)\n\n            if (\n                quantize_denoised\n                and not isinstance(self.first_stage_model, AutoencoderKL)\n                and not isinstance(self.first_stage_model, IdentityFirstStage)\n            ):\n                # also display when quantizing x0 while sampling\n                with self.ema_scope('Plotting Quantized Denoised'):\n                    samples, z_denoise_row = self.sample_log(\n                        cond=c,\n                        batch_size=N,\n                        ddim=use_ddim,\n                        ddim_steps=ddim_steps,\n                        eta=ddim_eta,\n                        quantize_denoised=True,\n                    )\n                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,\n                    #                                      quantize_denoised=True)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log['samples_x0_quantized'] = x_samples\n\n            if inpaint:\n                # make a simple center square\n                b, h, w = z.shape[0], z.shape[2], z.shape[3]\n                mask = torch.ones(N, h, w).to(self.device)\n                # zeros will be filled in\n                mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0\n                mask = mask[:, None, ...]\n                with self.ema_scope('Plotting Inpaint'):\n\n                    samples, _ = self.sample_log(\n                        cond=c,\n                        batch_size=N,\n                        ddim=use_ddim,\n                        eta=ddim_eta,\n                        ddim_steps=ddim_steps,\n                        x0=z[:N],\n                        
mask=mask,\n                    )\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log['samples_inpainting'] = x_samples\n                log['mask'] = mask\n\n                # outpaint\n                with self.ema_scope('Plotting Outpaint'):\n                    samples, _ = self.sample_log(\n                        cond=c,\n                        batch_size=N,\n                        ddim=use_ddim,\n                        eta=ddim_eta,\n                        ddim_steps=ddim_steps,\n                        x0=z[:N],\n                        mask=mask,\n                    )\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log['samples_outpainting'] = x_samples\n\n        if plot_progressive_rows:\n            with self.ema_scope('Plotting Progressives'):\n                img, progressives = self.progressive_denoising(\n                    c,\n                    shape=(self.channels, self.image_size, self.image_size),\n                    batch_size=N,\n                )\n            prog_row = self._get_denoise_row_from_list(\n                progressives, desc='Progressive Generation'\n            )\n            log['progressive_row'] = prog_row\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n\n        if self.embedding_manager is not None:\n            params = list(self.embedding_manager.embedding_parameters())\n            # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters())\n        else:\n            params = list(self.model.parameters())\n            if self.cond_stage_trainable:\n                print(\n                    
f'{self.__class__.__name__}: Also optimizing conditioner params!'\n                )\n                params = params + list(self.cond_stage_model.parameters())\n            if self.learn_logvar:\n                print('Diffusion model optimizing logvar')\n                params.append(self.logvar)\n        opt = torch.optim.AdamW(params, lr=lr)\n        if self.use_scheduler:\n            assert 'target' in self.scheduler_config\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print('Setting up LambdaLR scheduler...')\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1,\n                }\n            ]\n            return [opt], scheduler\n        return opt\n\n    @torch.no_grad()\n    def to_rgb(self, x):\n        x = x.float()\n        if not hasattr(self, 'colorize'):\n            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)\n        x = nn.functional.conv2d(x, weight=self.colorize)\n        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0\n        return x\n\n    @rank_zero_only\n    def on_save_checkpoint(self, checkpoint):\n        checkpoint.clear()\n\n        if os.path.isdir(self.trainer.checkpoint_callback.dirpath):\n            self.embedding_manager.save(\n                os.path.join(\n                    self.trainer.checkpoint_callback.dirpath, 'embeddings.pt'\n                )\n            )\n\n            if (self.global_step - self.emb_ckpt_counter) > 500:\n                self.embedding_manager.save(\n                    os.path.join(\n                        self.trainer.checkpoint_callback.dirpath,\n                        f'embeddings_gs-{self.global_step}.pt',\n                    )\n                )\n\n                self.emb_ckpt_counter += 500\n\n\nclass DiffusionWrapper(pl.LightningModule):\n    def __init__(self, diff_model_config, 
conditioning_key):\n        super().__init__()\n        self.diffusion_model = instantiate_from_config(diff_model_config)\n        self.conditioning_key = conditioning_key\n        assert self.conditioning_key in [\n            None,\n            'concat',\n            'crossattn',\n            'hybrid',\n            'adm',\n        ]\n\n    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):\n        if self.conditioning_key is None:\n            out = self.diffusion_model(x, t)\n        elif self.conditioning_key == 'concat':\n            xc = torch.cat([x] + c_concat, dim=1)\n            out = self.diffusion_model(xc, t)\n        elif self.conditioning_key == 'crossattn':\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(x, t, context=cc)\n        elif self.conditioning_key == 'hybrid':\n            xc = torch.cat([x] + c_concat, dim=1)\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(xc, t, context=cc)\n        elif self.conditioning_key == 'adm':\n            cc = c_crossattn[0]\n            out = self.diffusion_model(x, t, y=cc)\n        else:\n            raise NotImplementedError()\n\n        return out\n\n\nclass Layout2ImgDiffusion(LatentDiffusion):\n    # TODO: move all layout-specific hacks to this class\n    def __init__(self, cond_stage_key, *args, **kwargs):\n        assert (\n            cond_stage_key == 'coordinates_bbox'\n        ), 'Layout2ImgDiffusion only for cond_stage_key=\"coordinates_bbox\"'\n        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)\n\n    def log_images(self, batch, N=8, *args, **kwargs):\n        logs = super().log_images(batch=batch, N=N, *args, **kwargs)\n\n        key = 'train' if self.training else 'validation'\n        dset = self.trainer.datamodule.datasets[key]\n        mapper = dset.conditional_builders[self.cond_stage_key]\n\n        bbox_imgs = []\n        map_fn = lambda catno: dset.get_textual_label(\n    
        dset.get_category_id(catno)\n        )\n        for tknzd_bbox in batch[self.cond_stage_key][:N]:\n            bboximg = mapper.plot(\n                tknzd_bbox.detach().cpu(), map_fn, (256, 256)\n            )\n            bbox_imgs.append(bboximg)\n\n        cond_img = torch.stack(bbox_imgs, dim=0)\n        logs['bbox_image'] = cond_img\n        return logs\n"
  },
  {
    "path": "src/stablediffusion/ldm/models/diffusion/ksampler.py",
    "content": "\"\"\"wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers\"\"\"\nimport k_diffusion as K\nimport torch\nimport torch.nn as nn\nfrom src.stablediffusion.ldm.dream.devices import choose_torch_device\n\nclass CFGDenoiser(nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self.inner_model = model\n\n    def forward(self, x, sigma, uncond, cond, cond_scale):\n        x_in = torch.cat([x] * 2)\n        sigma_in = torch.cat([sigma] * 2)\n        cond_in = torch.cat([uncond, cond])\n        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)\n        return uncond + (cond - uncond) * cond_scale\n\n\nclass KSampler(object):\n    def __init__(self, model, schedule='lms', device=None, **kwargs):\n        super().__init__()\n        self.model = K.external.CompVisDenoiser(model)\n        self.schedule = schedule\n        self.device   = device or choose_torch_device()\n\n        def forward(self, x, sigma, uncond, cond, cond_scale):\n            x_in = torch.cat([x] * 2)\n            sigma_in = torch.cat([sigma] * 2)\n            cond_in = torch.cat([uncond, cond])\n            uncond, cond = self.inner_model(\n                x_in, sigma_in, cond=cond_in\n            ).chunk(2)\n            return uncond + (cond - uncond) * cond_scale\n\n    # most of these arguments are ignored and are only present for compatibility with\n    # other samples\n    @torch.no_grad()\n    def sample(\n        self,\n        S,\n        batch_size,\n        shape,\n        conditioning=None,\n        callback=None,\n        normals_sequence=None,\n        img_callback=None,\n        quantize_x0=False,\n        eta=0.0,\n        mask=None,\n        x0=None,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        verbose=True,\n        x_T=None,\n        log_every_t=100,\n        
unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n        **kwargs,\n    ):\n        def route_callback(k_callback_values):\n            if img_callback is not None:\n                img_callback(k_callback_values['x'], k_callback_values['i'])\n\n        sigmas = self.model.get_sigmas(S)\n        if x_T is not None:\n            x = x_T * sigmas[0]\n        else:\n            x = (\n                torch.randn([batch_size, *shape], device=self.device)\n                * sigmas[0]\n            )   # for GPU draw\n        model_wrap_cfg = CFGDenoiser(self.model)\n        extra_args = {\n            'cond': conditioning,\n            'uncond': unconditional_conditioning,\n            'cond_scale': unconditional_guidance_scale,\n        }\n        return (\n            K.sampling.__dict__[f'sample_{self.schedule}'](\n                model_wrap_cfg, x, sigmas, extra_args=extra_args,\n                callback=route_callback\n            ),\n            None,\n        )\n"
  },
  {
    "path": "src/stablediffusion/ldm/models/diffusion/plms.py",
    "content": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\nfrom src.stablediffusion.ldm.dream.devices import choose_torch_device\n\nfrom src.stablediffusion.ldm.modules.diffusionmodules.util import (\n    make_ddim_sampling_parameters,\n    make_ddim_timesteps,\n    noise_like,\n)\n\n\nclass PLMSSampler(object):\n    def __init__(self, model, schedule='linear', device=None, **kwargs):\n        super().__init__()\n        self.model = model\n        self.ddpm_num_timesteps = model.num_timesteps\n        self.schedule = schedule\n        self.device   = device if device else choose_torch_device()\n\n    def register_buffer(self, name, attr):\n        if type(attr) == torch.Tensor:\n            if attr.device != torch.device(self.device):\n                attr = attr.to(torch.float32).to(torch.device(self.device))\n        setattr(self, name, attr)\n\n    def make_schedule(\n        self,\n        ddim_num_steps,\n        ddim_discretize='uniform',\n        ddim_eta=0.0,\n        verbose=True,\n    ):\n        if ddim_eta != 0:\n            raise ValueError('ddim_eta must be 0 for PLMS')\n        self.ddim_timesteps = make_ddim_timesteps(\n            ddim_discr_method=ddim_discretize,\n            num_ddim_timesteps=ddim_num_steps,\n            num_ddpm_timesteps=self.ddpm_num_timesteps,\n            verbose=verbose,\n        )\n        alphas_cumprod = self.model.alphas_cumprod\n        assert (\n            alphas_cumprod.shape[0] == self.ddpm_num_timesteps\n        ), 'alphas have to be defined for each timestep'\n        to_torch = (\n            lambda x: x.clone()\n            .detach()\n            .to(torch.float32)\n            .to(self.model.device)\n        )\n\n        self.register_buffer('betas', to_torch(self.model.betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer(\n            'alphas_cumprod_prev', 
to_torch(self.model.alphas_cumprod_prev)\n        )\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\n            'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))\n        )\n        self.register_buffer(\n            'sqrt_one_minus_alphas_cumprod',\n            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            'log_one_minus_alphas_cumprod',\n            to_torch(np.log(1.0 - alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            'sqrt_recip_alphas_cumprod',\n            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            'sqrt_recipm1_alphas_cumprod',\n            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),\n        )\n\n        # ddim sampling parameters\n        (\n            ddim_sigmas,\n            ddim_alphas,\n            ddim_alphas_prev,\n        ) = make_ddim_sampling_parameters(\n            alphacums=alphas_cumprod.cpu(),\n            ddim_timesteps=self.ddim_timesteps,\n            eta=ddim_eta,\n            verbose=verbose,\n        )\n        self.register_buffer('ddim_sigmas', ddim_sigmas)\n        self.register_buffer('ddim_alphas', ddim_alphas)\n        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n        self.register_buffer(\n            'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)\n        )\n        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n            (1 - self.alphas_cumprod_prev)\n            / (1 - self.alphas_cumprod)\n            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)\n        )\n        self.register_buffer(\n            'ddim_sigmas_for_original_num_steps',\n            sigmas_for_original_sampling_steps,\n        )\n\n    @torch.no_grad()\n    def sample(\n        self,\n        S,\n        batch_size,\n        shape,\n        conditioning=None,\n        callback=None,\n       
 normals_sequence=None,\n        img_callback=None,\n        quantize_x0=False,\n        eta=0.0,\n        mask=None,\n        x0=None,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        verbose=True,\n        x_T=None,\n        log_every_t=100,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n        **kwargs,\n    ):\n        if conditioning is not None:\n            if isinstance(conditioning, dict):\n                cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n                if cbs != batch_size:\n                    print(\n                        f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'\n                    )\n            else:\n                if conditioning.shape[0] != batch_size:\n                    print(\n                        f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'\n                    )\n\n        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n        # sampling\n        C, H, W = shape\n        size = (batch_size, C, H, W)\n        #        print(f'Data shape for PLMS sampling is {size}')\n\n        samples, intermediates = self.plms_sampling(\n            conditioning,\n            size,\n            callback=callback,\n            img_callback=img_callback,\n            quantize_denoised=quantize_x0,\n            mask=mask,\n            x0=x0,\n            ddim_use_original_steps=False,\n            noise_dropout=noise_dropout,\n            temperature=temperature,\n            score_corrector=score_corrector,\n            corrector_kwargs=corrector_kwargs,\n            x_T=x_T,\n            log_every_t=log_every_t,\n            unconditional_guidance_scale=unconditional_guidance_scale,\n            
unconditional_conditioning=unconditional_conditioning,\n        )\n        return samples, intermediates\n\n    @torch.no_grad()\n    def plms_sampling(\n        self,\n        cond,\n        shape,\n        x_T=None,\n        ddim_use_original_steps=False,\n        callback=None,\n        timesteps=None,\n        quantize_denoised=False,\n        mask=None,\n        x0=None,\n        img_callback=None,\n        log_every_t=100,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n    ):\n        device = self.model.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        if timesteps is None:\n            timesteps = (\n                self.ddpm_num_timesteps\n                if ddim_use_original_steps\n                else self.ddim_timesteps\n            )\n        elif timesteps is not None and not ddim_use_original_steps:\n            subset_end = (\n                int(\n                    min(timesteps / self.ddim_timesteps.shape[0], 1)\n                    * self.ddim_timesteps.shape[0]\n                )\n                - 1\n            )\n            timesteps = self.ddim_timesteps[:subset_end]\n\n        intermediates = {'x_inter': [img], 'pred_x0': [img]}\n        time_range = (\n            list(reversed(range(0, timesteps)))\n            if ddim_use_original_steps\n            else np.flip(timesteps)\n        )\n        total_steps = (\n            timesteps if ddim_use_original_steps else timesteps.shape[0]\n        )\n        #        print(f\"Running PLMS Sampling with {total_steps} timesteps\")\n\n        iterator = tqdm(\n            time_range,\n            desc='PLMS Sampler',\n            total=total_steps,\n            dynamic_ncols=True,\n        )\n        old_eps = []\n\n        for i, 
step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full((b,), step, device=device, dtype=torch.long)\n            ts_next = torch.full(\n                (b,),\n                time_range[min(i + 1, len(time_range) - 1)],\n                device=device,\n                dtype=torch.long,\n            )\n\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.model.q_sample(\n                    x0, ts\n                )  # TODO: deterministic forward pass?\n                img = img_orig * mask + (1.0 - mask) * img\n\n            outs = self.p_sample_plms(\n                img,\n                cond,\n                ts,\n                index=index,\n                use_original_steps=ddim_use_original_steps,\n                quantize_denoised=quantize_denoised,\n                temperature=temperature,\n                noise_dropout=noise_dropout,\n                score_corrector=score_corrector,\n                corrector_kwargs=corrector_kwargs,\n                unconditional_guidance_scale=unconditional_guidance_scale,\n                unconditional_conditioning=unconditional_conditioning,\n                old_eps=old_eps,\n                t_next=ts_next,\n            )\n            img, pred_x0, e_t = outs\n            old_eps.append(e_t)\n            if len(old_eps) >= 4:\n                old_eps.pop(0)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(pred_x0, i)\n\n            if index % log_every_t == 0 or index == total_steps - 1:\n                intermediates['x_inter'].append(img)\n                intermediates['pred_x0'].append(pred_x0)\n\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_plms(\n        self,\n        x,\n        c,\n        t,\n        index,\n        repeat_noise=False,\n        use_original_steps=False,\n        quantize_denoised=False,\n        
temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        old_eps=None,\n        t_next=None,\n    ):\n        b, *_, device = *x.shape, x.device\n\n        def get_model_output(x, t):\n            if (\n                unconditional_conditioning is None\n                or unconditional_guidance_scale == 1.0\n            ):\n                e_t = self.model.apply_model(x, t, c)\n            else:\n                x_in = torch.cat([x] * 2)\n                t_in = torch.cat([t] * 2)\n                c_in = torch.cat([unconditional_conditioning, c])\n                e_t_uncond, e_t = self.model.apply_model(\n                    x_in, t_in, c_in\n                ).chunk(2)\n                e_t = e_t_uncond + unconditional_guidance_scale * (\n                    e_t - e_t_uncond\n                )\n\n            if score_corrector is not None:\n                assert self.model.parameterization == 'eps'\n                e_t = score_corrector.modify_score(\n                    self.model, e_t, x, t, c, **corrector_kwargs\n                )\n\n            return e_t\n\n        alphas = (\n            self.model.alphas_cumprod\n            if use_original_steps\n            else self.ddim_alphas\n        )\n        alphas_prev = (\n            self.model.alphas_cumprod_prev\n            if use_original_steps\n            else self.ddim_alphas_prev\n        )\n        sqrt_one_minus_alphas = (\n            self.model.sqrt_one_minus_alphas_cumprod\n            if use_original_steps\n            else self.ddim_sqrt_one_minus_alphas\n        )\n        sigmas = (\n            self.model.ddim_sigmas_for_original_num_steps\n            if use_original_steps\n            else self.ddim_sigmas\n        )\n\n        def get_x_prev_and_pred_x0(e_t, index):\n            # select parameters corresponding to the currently considered timestep\n 
           a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n            a_prev = torch.full(\n                (b, 1, 1, 1), alphas_prev[index], device=device\n            )\n            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n            sqrt_one_minus_at = torch.full(\n                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device\n            )\n\n            # current prediction for x_0\n            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n            if quantize_denoised:\n                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n            # direction pointing to x_t\n            dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t\n            noise = (\n                sigma_t\n                * noise_like(x.shape, device, repeat_noise)\n                * temperature\n            )\n            if noise_dropout > 0.0:\n                noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n            return x_prev, pred_x0\n\n        e_t = get_model_output(x, t)\n        if len(old_eps) == 0:\n            # Pseudo Improved Euler (2nd order)\n            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)\n            e_t_next = get_model_output(x_prev, t_next)\n            e_t_prime = (e_t + e_t_next) / 2\n        elif len(old_eps) == 1:\n            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)\n            e_t_prime = (3 * e_t - old_eps[-1]) / 2\n        elif len(old_eps) == 2:\n            # 3nd order Pseudo Linear Multistep (Adams-Bashforth)\n            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12\n        elif len(old_eps) >= 3:\n            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)\n            e_t_prime = (\n                55 * e_t\n                - 59 * old_eps[-1]\n                + 37 * old_eps[-2]\n                - 9 * old_eps[-3]\n            ) / 
24\n\n        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)\n\n        return x_prev, pred_x0, e_t\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/attention.py",
    "content": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfrom einops import rearrange, repeat\n\nfrom src.stablediffusion.ldm.modules.diffusionmodules.util import checkpoint\n\nimport psutil\n\ndef exists(val):\n    return val is not None\n\n\ndef uniq(arr):\n    return{el: True for el in arr}.keys()\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef max_neg_value(t):\n    return -torch.finfo(t.dtype).max\n\n\ndef init_(tensor):\n    dim = tensor.shape[-1]\n    std = 1 / math.sqrt(dim)\n    tensor.uniform_(-std, std)\n    return tensor\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU()\n        ) if not glu else GEGLU(dim, inner_dim)\n\n        self.net = nn.Sequential(\n            project_in,\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass LinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.heads = 
heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)\n        k = k.softmax(dim=-1)  \n        context = torch.einsum('bhdn,bhen->bhde', k, v)\n        out = torch.einsum('bhde,bhdn->bhen', context, q)\n        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)\n        return self.to_out(out)\n\n\nclass SpatialSelfAttention(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b,c,h,w = 
q.shape\n        q = rearrange(q, 'b c h w -> b (h w) c')\n        k = rearrange(k, 'b c h w -> b c (h w)')\n        w_ = torch.einsum('bij,bjk->bik', q, k)\n\n        w_ = w_ * (int(c)**(-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = rearrange(v, 'b c h w -> b c (h w)')\n        w_ = rearrange(w_, 'b i j -> b j i')\n        h_ = torch.einsum('bij,bjk->bik', v, w_)\n        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)\n        h_ = self.proj_out(h_)\n\n        return x+h_\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):\n        super().__init__()\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim),\n            nn.Dropout(dropout)\n        )\n\n        if not torch.cuda.is_available():\n            mem_av = psutil.virtual_memory().available / (1024**3)\n            if mem_av > 32:\n                self.einsum_op = self.einsum_op_v1\n            elif mem_av > 12:\n                self.einsum_op = self.einsum_op_v2\n            else:\n                self.einsum_op = self.einsum_op_v3   \n            del mem_av \n        else:\n            self.einsum_op = self.einsum_op_v4\n\n    # mps 64-128 GB\n    def einsum_op_v1(self, q, k, v, r1):\n        if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096\n            s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go\n            s2 = s1.softmax(dim=-1, dtype=q.dtype)\n            del s1\n            r1 = einsum('b i j, b j d -> b i d', s2, 
v)\n            del s2\n        else:\n            # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err\n            # needs around half of that slice_size to not generate noise\n            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))\n            for i in range(0, q.shape[1], slice_size):\n                end = i + slice_size\n                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale\n                s2 = s1.softmax(dim=-1, dtype=r1.dtype)\n                del s1  \n                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)\n                del s2\n        return r1\n\n    # mps 16-32 GB (can be optimized)\n    def einsum_op_v2(self, q, k, v, r1):\n        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))\n        for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps\n            end = i + slice_size\n            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale\n            s2 = s1.softmax(dim=-1, dtype=r1.dtype)\n            del s1  \n            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)\n            del s2\n        return r1\n\n    # mps 8 GB\n    def einsum_op_v3(self, q, k, v, r1):\n        slice_size = 1\n        for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]\n            end = min(q.shape[0], i + slice_size)\n            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem\n            s1 *= self.scale\n            s2 = s1.softmax(dim=-1, dtype=r1.dtype)\n            del s1\n            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem\n            del s2\n        return r1\n\n    # cuda\n    def einsum_op_v4(self, q, k, v, r1):\n        stats = torch.cuda.memory_stats(q.device)\n        mem_active = stats['active_bytes.all.current']\n        mem_reserved = stats['reserved_bytes.all.current']\n        mem_free_cuda, _ = 
torch.cuda.mem_get_info(torch.cuda.current_device())\n        mem_free_torch = mem_reserved - mem_active\n        mem_free_total = mem_free_cuda + mem_free_torch\n\n        gb = 1024 ** 3\n        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4\n        mem_required = tensor_size * 2.5\n        steps = 1\n\n        if mem_required > mem_free_total:\n            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))\n\n        if steps > 64:\n            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64\n            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '\n                            f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')\n        \n        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]  \n        for i in range(0, q.shape[1], slice_size):\n            end = min(q.shape[1], i + slice_size)\n            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale\n            s2 = s1.softmax(dim=-1, dtype=r1.dtype)\n            del s1\n            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)\n            del s2 \n        return r1\n\n    def forward(self, x, context=None, mask=None):\n        h = self.heads\n\n        q_in = self.to_q(x)\n        context = default(context, x)\n        k_in = self.to_k(context)\n        v_in = self.to_v(context)\n        device_type = 'mps' if x.device.type == 'mps' else 'cuda'\n        del context, x\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))\n        del q_in, k_in, v_in\n        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)\n        r1 = self.einsum_op(q, k, v, r1)\n        del q, k, v\n\n        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)\n        del r1\n\n        return self.to_out(r2)\n\n\nclass BasicTransformerBlock(nn.Module):\n    def 
__init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):\n        super().__init__()\n        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,\n                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n        self.checkpoint = checkpoint\n\n    def forward(self, x, context=None):\n        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)\n\n    def _forward(self, x, context=None):\n        x = x.contiguous() if x.device.type == 'mps' else x\n        x = self.attn1(self.norm1(x)) + x\n        x = self.attn2(self.norm2(x), context=context) + x\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data.\n    First, project the input (aka embedding)\n    and reshape to b, t, d.\n    Then apply standard transformer action.\n    Finally, reshape to image\n    \"\"\"\n    def __init__(self, in_channels, n_heads, d_head,\n                 depth=1, dropout=0., context_dim=None):\n        super().__init__()\n        self.in_channels = in_channels\n        inner_dim = n_heads * d_head\n        self.norm = Normalize(in_channels)\n\n        self.proj_in = nn.Conv2d(in_channels,\n                                 inner_dim,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n\n        self.transformer_blocks = nn.ModuleList(\n            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)\n          
      for d in range(depth)]\n        )\n\n        self.proj_out = zero_module(nn.Conv2d(inner_dim,\n                                              in_channels,\n                                              kernel_size=1,\n                                              stride=1,\n                                              padding=0))\n\n    def forward(self, x, context=None):\n        # note: if no context is given, cross-attention defaults to self-attention\n        b, c, h, w = x.shape\n        x_in = x\n        x = self.norm(x)\n        x = self.proj_in(x)\n        x = rearrange(x, 'b c h w -> b (h w) c')\n        for block in self.transformer_blocks:\n            x = block(x, context=context)\n        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)\n        x = self.proj_out(x)\n        return x + x_in\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/diffusionmodules/__init__.py",
    "content": ""
  },
  {
    "path": "src/stablediffusion/ldm/modules/diffusionmodules/model.py",
    "content": "# pytorch_diffusion + derived encoder decoder\nimport gc\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import rearrange\n\nfrom src.stablediffusion.ldm.util import instantiate_from_config\nfrom src.stablediffusion.ldm.modules.attention import LinearAttention\n\nimport psutil\n\ndef get_timestep_embedding(timesteps, embedding_dim):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models:\n    From Fairseq.\n    Build sinusoidal embeddings.\n    This matches the implementation in tensor2tensor, but differs slightly\n    from the description in Section 3.5 of \"Attention Is All You Need\".\n    \"\"\"\n    assert len(timesteps.shape) == 1\n\n    half_dim = embedding_dim // 2\n    emb = math.log(10000) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)\n    emb = emb.to(device=timesteps.device)\n    emb = timesteps.float()[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0,1,0,0))\n    return emb\n\n\ndef nonlinearity(x):\n    # swish\n    return x*torch.sigmoid(x)\n\n\ndef Normalize(in_channels, num_groups=32):\n    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n      
  return x\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=2,\n                                        padding=0)\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0,1,0,1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,\n                 dropout, temb_channels=512):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(in_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                                     padding=1)\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels,\n                                             out_channels)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(out_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                    
                 padding=1)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(in_channels,\n                                                     out_channels,\n                                                     kernel_size=3,\n                                                     stride=1,\n                                                     padding=1)\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(in_channels,\n                                                    out_channels,\n                                                    kernel_size=1,\n                                                    stride=1,\n                                                    padding=0)\n\n    def forward(self, x, temb):\n        h1 = x\n        h2 = self.norm1(h1)\n        del h1\n\n        h3 = nonlinearity(h2)\n        del h2\n\n        h4 = self.conv1(h3)\n        del h3\n\n        if temb is not None:\n            h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]\n\n        h5 = self.norm2(h4)\n        del h4\n\n        h6 = nonlinearity(h5)\n        del h5\n\n        h7 = self.dropout(h6)\n        del h6\n\n        h8 = self.conv2(h7)\n        del h7\n\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n\n        return x + h8\n\nclass LinAttnBlock(LinearAttention):\n    \"\"\"to match AttnBlock usage\"\"\"\n    def __init__(self, in_channels):\n        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                
                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q1 = self.q(h_)\n        k1 = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b, c, h, w = q1.shape\n\n        q2 = q1.reshape(b, c, h*w)\n        del q1\n\n        q = q2.permute(0, 2, 1)   # b,hw,c\n        del q2\n\n        k = k1.reshape(b, c, h*w) # b,c,hw\n        del k1\n\n        h_ = torch.zeros_like(k, device=q.device)\n\n        device_type = 'mps' if q.device.type == 'mps' else 'cuda'\n        if device_type == 'cuda':\n            stats = torch.cuda.memory_stats(q.device)\n            mem_active = stats['active_bytes.all.current']\n            mem_reserved = stats['reserved_bytes.all.current']\n            mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())\n            mem_free_torch = mem_reserved - mem_active\n            mem_free_total = mem_free_cuda + mem_free_torch\n\n            tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4\n            mem_required = tensor_size * 2.5\n            steps = 1\n\n            if mem_required > mem_free_total:\n                steps = 
2**(math.ceil(math.log(mem_required / mem_free_total, 2)))\n            \n            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]\n\n        else:\n            if psutil.virtual_memory().available / (1024**3) < 12:\n                slice_size = 1\n            else:\n                slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])))\n        \n        for i in range(0, q.shape[1], slice_size):\n            end = i + slice_size\n\n            w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n            w2 = w1 * (int(c)**(-0.5))\n            del w1\n            w3 = torch.nn.functional.softmax(w2, dim=2)\n            del w2\n\n            # attend to values\n            v1 = v.reshape(b, c, h*w)\n            w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)\n            del w3\n\n            h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n            del v1, w4\n\n        h2 = h_.reshape(b, c, h, w)\n        del h_\n\n        h3 = self.proj_out(h2)\n        del h2\n\n        h3 += x\n\n        return h3\n\n\ndef make_attn(in_channels, attn_type=\"vanilla\"):\n    assert attn_type in [\"vanilla\", \"linear\", \"none\"], f'attn_type {attn_type} unknown'\n    print(f\"making attention of type '{attn_type}' with {in_channels} in_channels\")\n    if attn_type == \"vanilla\":\n        return AttnBlock(in_channels)\n    elif attn_type == \"none\":\n        return nn.Identity(in_channels)\n    else:\n        return LinAttnBlock(in_channels)\n\n\nclass Model(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, use_timestep=True, use_linear_attn=False, attn_type=\"vanilla\"):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n 
       self.temb_ch = self.ch*4\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        self.use_timestep = use_timestep\n        if self.use_timestep:\n            # timestep embedding\n            self.temb = nn.Module()\n            self.temb.dense = nn.ModuleList([\n                torch.nn.Linear(self.ch,\n                                self.temb_ch),\n                torch.nn.Linear(self.temb_ch,\n                                self.temb_ch),\n            ])\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # 
middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            skip_in = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                if i_block == self.num_res_blocks:\n                    skip_in = ch*in_ch_mult[i_level]\n                block.append(ResnetBlock(in_channels=block_in+skip_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        
kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x, t=None, context=None):\n        #assert x.shape[2] == x.shape[3] == self.resolution\n        if context is not None:\n            # assume aligned context, cat along channel axis\n            x = torch.cat((x, context), dim=1)\n        if self.use_timestep:\n            # timestep embedding\n            assert t is not None\n            temb = get_timestep_embedding(t, self.ch)\n            temb = self.temb.dense[0](temb)\n            temb = nonlinearity(temb)\n            temb = self.temb.dense[1](temb)\n        else:\n            temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](\n                    torch.cat([h, hs.pop()], dim=1), temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n    def get_last_layer(self):\n        return self.conv_out.weight\n\n\nclass 
Encoder(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type=\"vanilla\",\n                 **ignore_kwargs):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.in_ch_mult = in_ch_mult\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = 
curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        2*z_channels if double_z else z_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        # timestep embedding\n        temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,\n                 attn_type=\"vanilla\", **ignorekwargs):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.tanh_out = tanh_out\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,)+tuple(ch_mult)\n        block_in = ch*ch_mult[self.num_resolutions-1]\n        curr_res = resolution // 2**(self.num_resolutions-1)\n        self.z_shape = (1,z_channels,curr_res,curr_res)\n        print(\"Working with z of shape {} = {} dimensions.\".format(\n            self.z_shape, np.prod(self.z_shape)))\n\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(z_channels,\n                                       block_in,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in 
reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, z):\n        #assert z.shape[1:] == self.z_shape[1:]\n        self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h1 = self.conv_in(z)\n\n        # middle\n        h2 = self.mid.block_1(h1, temb)\n        del h1\n\n        h3 = self.mid.attn_1(h2)\n        del h2\n\n        h = self.mid.block_2(h3, temb)\n        del h3\n\n        # prepare for up sampling\n        device_type = 'mps' if h.device.type == 'mps' else 'cuda'\n        gc.collect()\n        if device_type == 'cuda':\n            torch.cuda.empty_cache()\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in 
range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](h, temb)\n                if len(self.up[i_level].attn) > 0:\n                    t = h\n                    h = self.up[i_level].attn[i_block](t)\n                    del t\n\n            if i_level != 0:\n                t = h\n                h = self.up[i_level].upsample(t)\n                del t\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h1 = self.norm_out(h)\n        del h\n\n        h2 = nonlinearity(h1)\n        del h1\n\n        h = self.conv_out(h2)\n        del h2\n\n        if self.tanh_out:\n            t = h\n            h = torch.tanh(t)\n            del t\n\n        return h\n\n\nclass SimpleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, *args, **kwargs):\n        super().__init__()\n        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),\n                                     ResnetBlock(in_channels=in_channels,\n                                                 out_channels=2 * in_channels,\n                                                 temb_channels=0, dropout=0.0),\n                                     ResnetBlock(in_channels=2 * in_channels,\n                                                out_channels=4 * in_channels,\n                                                temb_channels=0, dropout=0.0),\n                                     ResnetBlock(in_channels=4 * in_channels,\n                                                out_channels=2 * in_channels,\n                                                temb_channels=0, dropout=0.0),\n                                     nn.Conv2d(2*in_channels, in_channels, 1),\n                                     Upsample(in_channels, with_conv=True)])\n        # end\n        self.norm_out = Normalize(in_channels)\n        self.conv_out = torch.nn.Conv2d(in_channels,\n                                        out_channels,\n                           
             kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        for i, layer in enumerate(self.model):\n            if i in [1,2,3]:\n                x = layer(x, None)\n            else:\n                x = layer(x)\n\n        h = self.norm_out(x)\n        h = nonlinearity(h)\n        x = self.conv_out(h)\n        return x\n\n\nclass UpsampleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,\n                 ch_mult=(2,2), dropout=0.0):\n        super().__init__()\n        # upsampling\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        block_in = in_channels\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.res_blocks = nn.ModuleList()\n        self.upsample_blocks = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            res_block = []\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                res_block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n            self.res_blocks.append(nn.ModuleList(res_block))\n            if i_level != self.num_resolutions - 1:\n                self.upsample_blocks.append(Upsample(block_in, True))\n                curr_res = curr_res * 2\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        
padding=1)\n\n    def forward(self, x):\n        # upsampling\n        h = x\n        for k, i_level in enumerate(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.res_blocks[i_level][i_block](h, None)\n            if i_level != self.num_resolutions - 1:\n                h = self.upsample_blocks[k](h)\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass LatentRescaler(nn.Module):\n    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):\n        super().__init__()\n        # residual block, interpolate, residual block\n        self.factor = factor\n        self.conv_in = nn.Conv2d(in_channels,\n                                 mid_channels,\n                                 kernel_size=3,\n                                 stride=1,\n                                 padding=1)\n        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,\n                                                     out_channels=mid_channels,\n                                                     temb_channels=0,\n                                                     dropout=0.0) for _ in range(depth)])\n        self.attn = AttnBlock(mid_channels)\n        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,\n                                                     out_channels=mid_channels,\n                                                     temb_channels=0,\n                                                     dropout=0.0) for _ in range(depth)])\n\n        self.conv_out = nn.Conv2d(mid_channels,\n                                  out_channels,\n                                  kernel_size=1,\n                                  )\n\n    def forward(self, x):\n        x = self.conv_in(x)\n        for block in self.res_block1:\n            x = block(x, None)\n        x = torch.nn.functional.interpolate(x, 
size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))\n        x = self.attn(x)\n        for block in self.res_block2:\n            x = block(x, None)\n        x = self.conv_out(x)\n        return x\n\n\nclass MergedRescaleEncoder(nn.Module):\n    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True,\n                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):\n        super().__init__()\n        intermediate_chn = ch * ch_mult[-1]\n        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,\n                               z_channels=intermediate_chn, double_z=False, resolution=resolution,\n                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,\n                               out_ch=None)\n        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,\n                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)\n\n    def forward(self, x):\n        x = self.encoder(x)\n        x = self.rescaler(x)\n        return x\n\n\nclass MergedRescaleDecoder(nn.Module):\n    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),\n                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):\n        super().__init__()\n        tmp_chn = z_channels*ch_mult[-1]\n        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,\n                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,\n                               ch_mult=ch_mult, resolution=resolution, ch=ch)\n        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, 
mid_channels=tmp_chn,\n                                       out_channels=tmp_chn, depth=rescale_module_depth)\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Upsampler(nn.Module):\n    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):\n        super().__init__()\n        assert out_size >= in_size\n        num_blocks = int(np.log2(out_size//in_size))+1\n        factor_up = 1.+ (out_size % in_size)\n        print(f\"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}\")\n        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,\n                                       out_channels=in_channels)\n        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,\n                               attn_resolutions=[], in_channels=None, ch=in_channels,\n                               ch_mult=[ch_mult for _ in range(num_blocks)])\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Resize(nn.Module):\n    def __init__(self, in_channels=None, learned=False, mode=\"bilinear\"):\n        super().__init__()\n        self.with_conv = learned\n        self.mode = mode\n        if self.with_conv:\n            print(f\"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode\")\n            raise NotImplementedError()\n            assert in_channels is not None\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=4,\n                                        stride=2,\n                                        padding=1)\n\n    def forward(self, x, 
scale_factor=1.0):\n        if scale_factor==1.0:\n            return x\n        else:\n            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)\n        return x\n\nclass FirstStagePostProcessor(nn.Module):\n\n    def __init__(self, ch_mult:list, in_channels,\n                 pretrained_model:nn.Module=None,\n                 reshape=False,\n                 n_channels=None,\n                 dropout=0.,\n                 pretrained_config=None):\n        super().__init__()\n        if pretrained_config is None:\n            assert pretrained_model is not None, 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.pretrained_model = pretrained_model\n        else:\n            assert pretrained_config is not None, 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.instantiate_pretrained(pretrained_config)\n\n        self.do_reshape = reshape\n\n        if n_channels is None:\n            n_channels = self.pretrained_model.encoder.ch\n\n        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)\n        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,\n                            stride=1,padding=1)\n\n        blocks = []\n        downs = []\n        ch_in = n_channels\n        for m in ch_mult:\n            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))\n            ch_in = m * n_channels\n            downs.append(Downsample(ch_in, with_conv=False))\n\n        self.model = nn.ModuleList(blocks)\n        self.downsampler = nn.ModuleList(downs)\n\n\n    def instantiate_pretrained(self, config):\n        model = instantiate_from_config(config)\n        self.pretrained_model = model.eval()\n        # self.pretrained_model.train = False\n        for param in self.pretrained_model.parameters():\n            param.requires_grad = False\n\n\n    @torch.no_grad()\n    def 
encode_with_pretrained(self,x):\n        c = self.pretrained_model.encode(x)\n        if isinstance(c, DiagonalGaussianDistribution):\n            c = c.mode()\n        return  c\n\n    def forward(self,x):\n        z_fs = self.encode_with_pretrained(x)\n        z = self.proj_norm(z_fs)\n        z = self.proj(z)\n        z = nonlinearity(z)\n\n        for submodel, downmodel in zip(self.model,self.downsampler):\n            z = submodel(z,temb=None)\n            z = downmodel(z)\n\n        if self.do_reshape:\n            z = rearrange(z,'b c h w -> b (h w) c')\n        return z\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/diffusionmodules/openaimodel.py",
    "content": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom src.stablediffusion.ldm.modules.diffusionmodules.util import (\n    checkpoint,\n    conv_nd,\n    linear,\n    avg_pool_nd,\n    zero_module,\n    normalization,\n    timestep_embedding,\n)\nfrom src.stablediffusion.ldm.modules.attention import SpatialTransformer\n\n\n# dummy replace\ndef convert_module_to_f16(x):\n    pass\n\n\ndef convert_module_to_f32(x):\n    pass\n\n\n## go\nclass AttentionPool2d(nn.Module):\n    \"\"\"\n    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py\n    \"\"\"\n\n    def __init__(\n        self,\n        spacial_dim: int,\n        embed_dim: int,\n        num_heads_channels: int,\n        output_dim: int = None,\n    ):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(\n            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5\n        )\n        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)\n        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)\n        self.num_heads = embed_dim // num_heads_channels\n        self.attention = QKVAttention(self.num_heads)\n\n    def forward(self, x):\n        b, c, *_spatial = x.shape\n        x = x.reshape(b, c, -1)  # NC(HW)\n        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)\n        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)\n        x = self.qkv_proj(x)\n        x = self.attention(x)\n        x = self.c_proj(x)\n        return x[:, :, 0]\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        
\"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb, context=None):\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            elif isinstance(layer, SpatialTransformer):\n                x = layer(x, context)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(\n        self, channels, use_conv, dims=2, out_channels=None, padding=1\n    ):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        if use_conv:\n            self.conv = conv_nd(\n                dims, self.channels, self.out_channels, 3, padding=padding\n            )\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(\n                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'\n            )\n        else:\n            x = F.interpolate(x, scale_factor=2, mode='nearest')\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass TransposedUpsample(nn.Module):\n    \"\"\"Learned 2x upsampling without padding\"\"\"\n\n    def __init__(self, channels, out_channels=None, ks=5):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n\n  
      self.up = nn.ConvTranspose2d(\n            self.channels, self.out_channels, kernel_size=ks, stride=2\n        )\n\n    def forward(self, x):\n        return self.up(x)\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(\n        self, channels, use_conv, dims=2, out_channels=None, padding=1\n    ):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(\n                dims,\n                self.channels,\n                self.out_channels,\n                3,\n                stride=stride,\n                padding=padding,\n            )\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param 
use_checkpoint: if True, use gradient checkpointing on this module.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use this block for downsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        emb_channels,\n        dropout,\n        out_channels=None,\n        use_conv=False,\n        use_scale_shift_norm=False,\n        dims=2,\n        use_checkpoint=False,\n        up=False,\n        down=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_checkpoint = use_checkpoint\n        self.use_scale_shift_norm = use_scale_shift_norm\n\n        self.in_layers = nn.Sequential(\n            normalization(channels),\n            nn.SiLU(),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims)\n            self.x_upd = Upsample(channels, False, dims)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims)\n            self.x_upd = Downsample(channels, False, dims)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.emb_layers = nn.Sequential(\n            nn.SiLU(),\n            linear(\n                emb_channels,\n                2 * self.out_channels\n                if use_scale_shift_norm\n                else self.out_channels,\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels),\n            nn.SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(\n                conv_nd(\n                    dims, self.out_channels, self.out_channels, 3, padding=1\n                )\n            ),\n        )\n\n        if self.out_channels == 
channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 3, padding=1\n            )\n        else:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 1\n            )\n\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        return checkpoint(\n            self._forward, (x, emb), self.parameters(), self.use_checkpoint\n        )\n\n    def _forward(self, x, emb):\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = th.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        
num_heads=1,\n        num_head_channels=-1,\n        use_checkpoint=False,\n        use_new_attention_order=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        if num_head_channels == -1:\n            self.num_heads = num_heads\n        else:\n            assert (\n                channels % num_head_channels == 0\n            ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'\n            self.num_heads = channels // num_head_channels\n        self.use_checkpoint = use_checkpoint\n        self.norm = normalization(channels)\n        self.qkv = conv_nd(1, channels, channels * 3, 1)\n        if use_new_attention_order:\n            # split qkv before split heads\n            self.attention = QKVAttention(self.num_heads)\n        else:\n            # split heads before split qkv\n            self.attention = QKVAttentionLegacy(self.num_heads)\n\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n\n    def forward(self, x):\n        return checkpoint(\n            self._forward, (x,), self.parameters(), True\n        )   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!\n        # return pt_checkpoint(self._forward, x)  # pytorch\n\n    def _forward(self, x):\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1)\n        qkv = self.qkv(self.norm(x))\n        h = self.attention(qkv)\n        h = self.proj_out(h)\n        return (x + h).reshape(b, c, *spatial)\n\n\ndef count_flops_attn(model, _x, y):\n    \"\"\"\n    A counter for the `thop` package to count the operations in an\n    attention operation.\n    Meant to be used like:\n        macs, params = thop.profile(\n            model,\n            inputs=(inputs, timestamps),\n            custom_ops={QKVAttention: QKVAttention.count_flops},\n        )\n    \"\"\"\n    b, c, *spatial = y[0].shape\n    num_spatial = int(np.prod(spatial))\n    # We perform two matmuls with the same number of 
ops.\n    # The first computes the weight matrix, the second computes\n    # the combination of the value vectors.\n    matmul_ops = 2 * b * (num_spatial**2) * c\n    model.total_ops += th.DoubleTensor([matmul_ops])\n\n\nclass QKVAttentionLegacy(nn.Module):\n    \"\"\"\n    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(\n            ch, dim=1\n        )\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            'bct,bcs->bts', q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum('bts,bcs->bct', weight, v)\n        return a.reshape(bs, -1, length)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention and splits in a different order.\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, 
v = qkv.chunk(3, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            'bct,bcs->bts',\n            (q * scale).view(bs * self.n_heads, ch, length),\n            (k * scale).view(bs * self.n_heads, ch, length),\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\n            'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)\n        )\n        return a.reshape(bs, -1, length)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass UNetModel(nn.Module):\n    \"\"\"\n    The full UNet model with attention and timestep embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. 
May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param use_checkpoint: use gradient checkpointing to reduce memory usage.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_heads_channels: if specified, ignore num_heads and instead use\n                               a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling. Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    :param use_new_attention_order: use a different attention pattern for potentially\n                                    increased efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        num_classes=None,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=-1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        use_spatial_transformer=False,  # custom transformer support\n        transformer_depth=1,  # custom transformer support\n  
      context_dim=None,  # custom transformer support\n        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model\n        legacy=True,\n    ):\n        super().__init__()\n        if use_spatial_transformer:\n            assert (\n                context_dim is not None\n            ), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'\n\n        if context_dim is not None:\n            assert (\n                use_spatial_transformer\n            ), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'\n            from omegaconf.listconfig import ListConfig\n\n            if type(context_dim) == ListConfig:\n                context_dim = list(context_dim)\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert (\n                num_head_channels != -1\n            ), 'Either num_heads or num_head_channels has to be set'\n\n        if num_head_channels == -1:\n            assert (\n                num_heads != -1\n            ), 'Either num_heads or num_head_channels has to be set'\n\n        self.image_size = image_size\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n        self.predict_codebook_ids = n_embed is not None\n\n        time_embed_dim = model_channels * 4\n        
self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        # num_heads = 1\n                        dim_head = (\n                            ch // num_heads\n                            if use_spatial_transformer\n                            else num_head_channels\n                        )\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                          
  num_heads=num_heads,\n                            num_head_channels=dim_head,\n                            use_new_attention_order=use_new_attention_order,\n                        )\n                        if not use_spatial_transformer\n                        else SpatialTransformer(\n                            ch,\n                            num_heads,\n                            dim_head,\n                            depth=transformer_depth,\n                            context_dim=context_dim,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n        if legacy:\n            # num_heads = 1\n            dim_head = (\n                ch // 
num_heads\n                if use_spatial_transformer\n                else num_head_channels\n            )\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=dim_head,\n                use_new_attention_order=use_new_attention_order,\n            )\n            if not use_spatial_transformer\n            else SpatialTransformer(\n                ch,\n                num_heads,\n                dim_head,\n                depth=transformer_depth,\n                context_dim=context_dim,\n            ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(num_res_blocks + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if 
num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        # num_heads = 1\n                        dim_head = (\n                            ch // num_heads\n                            if use_spatial_transformer\n                            else num_head_channels\n                        )\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads_upsample,\n                            num_head_channels=dim_head,\n                            use_new_attention_order=use_new_attention_order,\n                        )\n                        if not use_spatial_transformer\n                        else SpatialTransformer(\n                            ch,\n                            num_heads,\n                            dim_head,\n                            depth=transformer_depth,\n                            context_dim=context_dim,\n                        )\n                    )\n                if level and i == num_res_blocks:\n                    out_ch = ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n 
                       )\n                    )\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(\n                conv_nd(dims, model_channels, out_channels, 3, padding=1)\n            ),\n        )\n        if self.predict_codebook_ids:\n            self.id_predictor = nn.Sequential(\n                normalization(ch),\n                conv_nd(dims, model_channels, n_embed, 1),\n                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits\n            )\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n        self.output_blocks.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n        self.output_blocks.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :return: an [N x C x ...] 
Tensor of outputs.\n        \"\"\"\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), 'must specify y if and only if the model is class-conditional'\n        hs = []\n        t_emb = timestep_embedding(\n            timesteps, self.model_channels, repeat_only=False\n        )\n        emb = self.time_embed(t_emb)\n\n        if self.num_classes is not None:\n            assert y.shape == (x.shape[0],)\n            emb = emb + self.label_emb(y)\n\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb, context)\n            hs.append(h)\n        h = self.middle_block(h, emb, context)\n        for module in self.output_blocks:\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, context)\n        h = h.type(x.dtype)\n        if self.predict_codebook_ids:\n            return self.id_predictor(h)\n        else:\n            return self.out(h)\n\n\nclass EncoderUNetModel(nn.Module):\n    \"\"\"\n    The half UNet model with attention and timestep embedding.\n    For usage, see UNet.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        pool='adaptive',\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = 
num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads,\n                            num_head_channels=num_head_channels,\n                            
use_new_attention_order=use_new_attention_order,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=num_head_channels,\n                use_new_attention_order=use_new_attention_order,\n            ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                
use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n        self.pool = pool\n        if pool == 'adaptive':\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                nn.AdaptiveAvgPool2d((1, 1)),\n                zero_module(conv_nd(dims, ch, out_channels, 1)),\n                nn.Flatten(),\n            )\n        elif pool == 'attention':\n            assert num_head_channels != -1\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                AttentionPool2d(\n                    (image_size // ds), ch, num_head_channels, out_channels\n                ),\n            )\n        elif pool == 'spatial':\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                nn.ReLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        elif pool == 'spatial_v2':\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                normalization(2048),\n                nn.SiLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        else:\n            raise NotImplementedError(f'Unexpected {pool} pooling')\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] 
Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :return: an [N x K] Tensor of outputs.\n        \"\"\"\n        emb = self.time_embed(\n            timestep_embedding(timesteps, self.model_channels)\n        )\n\n        results = []\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb)\n            if self.pool.startswith('spatial'):\n                results.append(h.type(x.dtype).mean(dim=(2, 3)))\n        h = self.middle_block(h, emb)\n        if self.pool.startswith('spatial'):\n            results.append(h.type(x.dtype).mean(dim=(2, 3)))\n            h = th.cat(results, axis=-1)\n            return self.out(h)\n        else:\n            h = h.type(x.dtype)\n            return self.out(h)\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/diffusionmodules/util.py",
    "content": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\n# and\n# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py\n#\n# thanks!\n\n\nimport os\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import repeat\n\nfrom src.stablediffusion.ldm.util import instantiate_from_config\n\n\ndef make_beta_schedule(\n    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3\n):\n    if schedule == 'linear':\n        betas = (\n            torch.linspace(\n                linear_start**0.5,\n                linear_end**0.5,\n                n_timestep,\n                dtype=torch.float64,\n            )\n            ** 2\n        )\n\n    elif schedule == 'cosine':\n        timesteps = (\n            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep\n            + cosine_s\n        )\n        alphas = timesteps / (1 + cosine_s) * np.pi / 2\n        alphas = torch.cos(alphas).pow(2)\n        alphas = alphas / alphas[0]\n        betas = 1 - alphas[1:] / alphas[:-1]\n        betas = np.clip(betas, a_min=0, a_max=0.999)\n\n    elif schedule == 'sqrt_linear':\n        betas = torch.linspace(\n            linear_start, linear_end, n_timestep, dtype=torch.float64\n        )\n    elif schedule == 'sqrt':\n        betas = (\n            torch.linspace(\n                linear_start, linear_end, n_timestep, dtype=torch.float64\n            )\n            ** 0.5\n        )\n    else:\n        raise ValueError(f\"schedule '{schedule}' unknown.\")\n    return betas.numpy()\n\n\ndef make_ddim_timesteps(\n    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True\n):\n    if ddim_discr_method == 'uniform':\n  
      c = num_ddpm_timesteps // num_ddim_timesteps\n        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))\n    elif ddim_discr_method == 'quad':\n        ddim_timesteps = (\n            (\n                np.linspace(\n                    0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps\n                )\n            )\n            ** 2\n        ).astype(int)\n    else:\n        raise NotImplementedError(\n            f'There is no ddim discretization method called \"{ddim_discr_method}\"'\n        )\n\n    # assert ddim_timesteps.shape[0] == num_ddim_timesteps\n    # add one to get the final alpha values right (the ones from first scale to data during sampling)\n#    steps_out = ddim_timesteps + 1\n    steps_out = ddim_timesteps\n\n    if verbose:\n        print(f'Selected timesteps for ddim sampler: {steps_out}')\n    return steps_out\n\n\ndef make_ddim_sampling_parameters(\n    alphacums, ddim_timesteps, eta, verbose=True\n):\n    # select alphas for computing the variance schedule\n    alphas = alphacums[ddim_timesteps]\n    alphas_prev = np.asarray(\n        [alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()\n    )\n\n    # according the the formula provided in https://arxiv.org/abs/2010.02502\n    sigmas = eta * np.sqrt(\n        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)\n    )\n    if verbose:\n        print(\n            f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'\n        )\n        print(\n            f'For the chosen value of eta, which is {eta}, '\n            f'this results in the following sigma_t schedule for ddim sampler {sigmas}'\n        )\n    return sigmas, alphas, alphas_prev\n\n\ndef betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    which defines the cumulative product of (1-beta) over time from t = [0,1].\n    :param num_diffusion_timesteps: 
the number of betas to produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if (\n        False\n    ):   # disabled checkpointing to allow requires_grad = False for main model\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, 
*output_grads):\n        ctx.input_tensors = [\n            x.detach().requires_grad_(True) for x in ctx.input_tensors\n        ]\n        with torch.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period)\n            * torch.arange(start=0, end=half, dtype=torch.float32)\n            / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat(\n                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1\n            )\n    else:\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\n    return embedding\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in 
module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef normalization(channels):\n    \"\"\"\n    Make a standard normalization layer.\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(32, channels)\n\n\n# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.\nclass SiLU(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f'unsupported dimensions: {dims}')\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f'unsupported dimensions: {dims}')\n\n\nclass HybridConditioner(nn.Module):\n    def __init__(self, c_concat_config, c_crossattn_config):\n        super().__init__()\n        self.concat_conditioner = 
instantiate_from_config(c_concat_config)\n        self.crossattn_conditioner = instantiate_from_config(\n            c_crossattn_config\n        )\n\n    def forward(self, c_concat, c_crossattn):\n        c_concat = self.concat_conditioner(c_concat)\n        c_crossattn = self.crossattn_conditioner(c_crossattn)\n        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}\n\n\ndef noise_like(shape, device, repeat=False):\n    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(\n        shape[0], *((1,) * (len(shape) - 1))\n    )\n    noise = lambda: torch.randn(shape, device=device)\n    return repeat_noise() if repeat else noise()\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/distributions/__init__.py",
    "content": ""
  },
  {
    "path": "src/stablediffusion/ldm/modules/distributions/distributions.py",
    "content": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n\n    def mode(self):\n        raise NotImplementedError()\n\n\nclass DiracDistribution(AbstractDistribution):\n    def __init__(self, value):\n        self.value = value\n\n    def sample(self):\n        return self.value\n\n    def mode(self):\n        return self.value\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean).to(\n                device=self.parameters.device\n            )\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape).to(\n            device=self.parameters.device\n        )\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        else:\n            if other is None:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,\n                    dim=[1, 2, 3],\n                )\n            else:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var\n                    - 1.0\n                    - self.logvar\n                    + other.logvar,\n                    dim=[1, 2, 3],\n                )\n\n    def nll(self, sample, dims=[1, 2, 3]):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n   
         logtwopi\n            + self.logvar\n            + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims,\n        )\n\n    def mode(self):\n        return self.mean\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, 'at least one argument must be a Tensor'\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + torch.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n\n        self.m_name2s_name = {}\n        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer(\n            'num_updates',\n            torch.tensor(0, dtype=torch.int)\n            if use_num_upates\n            else torch.tensor(-1, dtype=torch.int),\n        )\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                # remove as '.'-character is not allowed in buffers\n                s_name = name.replace('.', '')\n                self.m_name2s_name.update({name: s_name})\n                self.register_buffer(s_name, p.clone().detach().data)\n\n        self.collected_params = []\n\n    def forward(self, model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(\n                self.decay, (1 + self.num_updates) / (10 + self.num_updates)\n            )\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(\n                        m_param[key]\n                    )\n                    shadow_params[sname].sub_(\n                        one_minus_decay * (shadow_params[sname] - m_param[key])\n                    )\n                else:\n                    assert not key in self.m_name2s_name\n\n    def copy_to(self, model):\n        m_param = dict(model.named_parameters())\n        
shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                m_param[key].data.copy_(\n                    shadow_params[self.m_name2s_name[key]].data\n                )\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, parameters):\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in parameters]\n\n    def restore(self, parameters):\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, parameters):\n            param.data.copy_(c_param.data)\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/embedding_manager.py",
    "content": "from cmath import log\nimport torch\nfrom torch import nn\n\nimport sys\n\nfrom src.stablediffusion.ldm.data.personalized import per_img_token_list\nfrom transformers import CLIPTokenizer\nfrom functools import partial\n\nDEFAULT_PLACEHOLDER_TOKEN = ['*']\n\nPROGRESSIVE_SCALE = 2000\n\n\ndef get_clip_token_for_string(tokenizer, string):\n    batch_encoding = tokenizer(\n        string,\n        truncation=True,\n        max_length=77,\n        return_length=True,\n        return_overflowing_tokens=False,\n        padding='max_length',\n        return_tensors='pt',\n    )\n    tokens = batch_encoding['input_ids']\n    \"\"\" assert (\n        torch.count_nonzero(tokens - 49407) == 2\n    ), f\"String '{string}' maps to more than a single token. Please use another string\" \"\"\"\n\n    return tokens[0, 1]\n\n\ndef get_bert_token_for_string(tokenizer, string):\n    token = tokenizer(string)\n    # assert torch.count_nonzero(token) == 3, f\"String '{string}' maps to more than a single token. 
Please use another string\"\n\n    token = token[0, 1]\n\n    return token\n\n\ndef get_embedding_for_clip_token(embedder, token):\n    return embedder(token.unsqueeze(0))[0, 0]\n\n\nclass EmbeddingManager(nn.Module):\n    def __init__(\n        self,\n        embedder,\n        placeholder_strings=None,\n        initializer_words=None,\n        per_image_tokens=False,\n        num_vectors_per_token=1,\n        progressive_words=False,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.embedder = embedder\n\n        self.string_to_token_dict = {}\n        self.string_to_param_dict = nn.ParameterDict()\n\n        self.initial_embeddings = (\n            nn.ParameterDict()\n        )   # These should not be optimized\n\n        self.progressive_words = progressive_words\n        self.progressive_counter = 0\n\n        self.max_vectors_per_token = num_vectors_per_token\n\n        if hasattr(\n            embedder, 'tokenizer'\n        ):   # using Stable Diffusion's CLIP encoder\n            self.is_clip = True\n            get_token_for_string = partial(\n                get_clip_token_for_string, embedder.tokenizer\n            )\n            get_embedding_for_tkn = partial(\n                get_embedding_for_clip_token,\n                embedder.transformer.text_model.embeddings,\n            )\n            token_dim = 1280\n        else:   # using LDM's BERT encoder\n            self.is_clip = False\n            get_token_for_string = partial(\n                get_bert_token_for_string, embedder.tknz_fn\n            )\n            get_embedding_for_tkn = embedder.transformer.token_emb\n            token_dim = 1280\n\n        if per_image_tokens:\n            placeholder_strings.extend(per_img_token_list)\n\n        for idx, placeholder_string in enumerate(placeholder_strings):\n\n            token = get_token_for_string(placeholder_string)\n\n            if initializer_words and idx < len(initializer_words):\n                init_word_token = 
get_token_for_string(initializer_words[idx])\n\n                with torch.no_grad():\n                    init_word_embedding = get_embedding_for_tkn(\n                        init_word_token.cpu()\n                    )\n\n                token_params = torch.nn.Parameter(\n                    init_word_embedding.unsqueeze(0).repeat(\n                        num_vectors_per_token, 1\n                    ),\n                    requires_grad=True,\n                )\n                self.initial_embeddings[\n                    placeholder_string\n                ] = torch.nn.Parameter(\n                    init_word_embedding.unsqueeze(0).repeat(\n                        num_vectors_per_token, 1\n                    ),\n                    requires_grad=False,\n                )\n            else:\n                token_params = torch.nn.Parameter(\n                    torch.rand(\n                        size=(num_vectors_per_token, token_dim),\n                        requires_grad=True,\n                    )\n                )\n\n            self.string_to_token_dict[placeholder_string] = token\n            self.string_to_param_dict[placeholder_string] = token_params\n\n    def forward(\n        self,\n        tokenized_text,\n        embedded_text,\n    ):\n        b, n, device = *tokenized_text.shape, tokenized_text.device\n\n        for (\n            placeholder_string,\n            placeholder_token,\n        ) in self.string_to_token_dict.items():\n\n            placeholder_embedding = self.string_to_param_dict[\n                placeholder_string\n            ].to(device)\n\n            if (\n                self.max_vectors_per_token == 1\n            ):   # If there's only one vector per token, we can do a simple replacement\n                placeholder_idx = torch.where(\n                    tokenized_text == placeholder_token.to(device)\n                )\n                embedded_text[placeholder_idx] = placeholder_embedding\n            else:   # 
otherwise, need to insert and keep track of changing indices\n                if self.progressive_words:\n                    self.progressive_counter += 1\n                    max_step_tokens = (\n                        1 + self.progressive_counter // PROGRESSIVE_SCALE\n                    )\n                else:\n                    max_step_tokens = self.max_vectors_per_token\n\n                num_vectors_for_token = min(\n                    placeholder_embedding.shape[0], max_step_tokens\n                )\n\n                placeholder_rows, placeholder_cols = torch.where(\n                    tokenized_text == placeholder_token.to(device)\n                )\n\n                if placeholder_rows.nelement() == 0:\n                    continue\n\n                sorted_cols, sort_idx = torch.sort(\n                    placeholder_cols, descending=True\n                )\n                sorted_rows = placeholder_rows[sort_idx]\n\n                for idx in range(len(sorted_rows)):\n                    row = sorted_rows[idx]\n                    col = sorted_cols[idx]\n\n                    new_token_row = torch.cat(\n                        [\n                            tokenized_text[row][:col],\n                            placeholder_token.repeat(num_vectors_for_token).to(\n                                device\n                            ),\n                            tokenized_text[row][col + 1 :],\n                        ],\n                        axis=0,\n                    )[:n]\n                    new_embed_row = torch.cat(\n                        [\n                            embedded_text[row][:col],\n                            placeholder_embedding[:num_vectors_for_token],\n                            embedded_text[row][col + 1 :],\n                        ],\n                        axis=0,\n                    )[:n]\n\n                    embedded_text[row] = new_embed_row\n                    tokenized_text[row] = new_token_row\n\n 
       return embedded_text\n\n    def save(self, ckpt_path):\n        torch.save(\n            {\n                'string_to_token': self.string_to_token_dict,\n                'string_to_param': self.string_to_param_dict,\n            },\n            ckpt_path,\n        )\n\n    def load(self, ckpt_path, full=True):\n        ckpt = torch.load(ckpt_path, map_location='cpu')\n\n        # Handle .pt textual inversion files\n        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:\n            self.string_to_token_dict = ckpt[\"string_to_token\"]\n            self.string_to_param_dict = ckpt[\"string_to_param\"]\n\n        # Handle .bin textual inversion files from Huggingface Concepts\n        # https://huggingface.co/sd-concepts-library\n        else:\n            for token_str in list(ckpt.keys()):\n                token = get_clip_token_for_string(self.embedder.tokenizer, token_str)\n                self.string_to_token_dict[token_str] = token\n                ckpt[token_str] = torch.nn.Parameter(ckpt[token_str])\n                \n            self.string_to_param_dict.update(ckpt)\n\n        if not full:\n            for key, value in self.string_to_param_dict.items():\n                self.string_to_param_dict[key] = torch.nn.Parameter(value.half())\n\n        print(f'Added terms: {\", \".join(self.string_to_param_dict.keys())}')\n\n    def get_embedding_norms_squared(self):\n        all_params = torch.cat(\n            list(self.string_to_param_dict.values()), axis=0\n        )   # num_placeholders x embedding_dim\n        param_norm_squared = (all_params * all_params).sum(\n            axis=-1\n        )              # num_placeholders\n\n        return param_norm_squared\n\n    def embedding_parameters(self):\n        return self.string_to_param_dict.parameters()\n\n    def embedding_to_coarse_loss(self):\n\n        loss = 0.0\n        num_embeddings = len(self.initial_embeddings)\n\n        for key in self.initial_embeddings:\n            
optimized = self.string_to_param_dict[key]\n            coarse = self.initial_embeddings[key].clone().to(optimized.device)\n\n            loss = (\n                loss\n                + (optimized - coarse)\n                @ (optimized - coarse).T\n                / num_embeddings\n            )\n\n        return loss\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/encoders/__init__.py",
    "content": ""
  },
  {
    "path": "src/stablediffusion/ldm/modules/encoders/modules.py",
    "content": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom transformers import CLIPTokenizer, CLIPTextModel\nimport kornia\nfrom src.stablediffusion.ldm.dream.devices import choose_torch_device\n\nfrom src.stablediffusion.ldm.modules.x_transformer import (\n    Encoder,\n    TransformerWrapper,\n)  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test\n\n\ndef _expand_mask(mask, dtype, tgt_len=None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = (\n        mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n    )\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(\n        inverted_mask.to(torch.bool), torch.finfo(dtype).min\n    )\n\n\ndef _build_causal_attention_mask(bsz, seq_len, dtype):\n    # lazily create causal attention mask, with full attention between the vision tokens\n    # pytorch uses additive attention mask; fill with -inf\n    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)\n    mask.fill_(torch.tensor(torch.finfo(dtype).min))\n    mask.triu_(1)  # zero out the lower diagonal\n    mask = mask.unsqueeze(1)  # expand mask\n    return mask\n\n\nclass AbstractEncoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def encode(self, *args, **kwargs):\n        raise NotImplementedError\n\n\nclass ClassEmbedder(nn.Module):\n    def __init__(self, embed_dim, n_classes=1000, key='class'):\n        super().__init__()\n        self.key = key\n        self.embedding = nn.Embedding(n_classes, embed_dim)\n\n    def forward(self, batch, key=None):\n        if key is None:\n            key = self.key\n        # this is for use in crossattn\n        c = batch[key][:, None]\n        c = 
self.embedding(c)\n        return c\n\n\nclass TransformerEmbedder(AbstractEncoder):\n    \"\"\"Some transformer encoder layers\"\"\"\n\n    def __init__(\n        self,\n        n_embed,\n        n_layer,\n        vocab_size,\n        max_seq_len=77,\n        device=choose_torch_device(),\n    ):\n        super().__init__()\n        self.device = device\n        self.transformer = TransformerWrapper(\n            num_tokens=vocab_size,\n            max_seq_len=max_seq_len,\n            attn_layers=Encoder(dim=n_embed, depth=n_layer),\n        )\n\n    def forward(self, tokens):\n        tokens = tokens.to(self.device)  # meh\n        z = self.transformer(tokens, return_embeddings=True)\n        return z\n\n    def encode(self, x):\n        return self(x)\n\n\nclass BERTTokenizer(AbstractEncoder):\n    \"\"\"Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)\"\"\"\n\n    def __init__(\n        self, device=choose_torch_device(), vq_interface=True, max_length=77\n    ):\n        super().__init__()\n        from transformers import (\n            BertTokenizerFast,\n        )  # TODO: add to reuquirements\n\n        # Modified to allow to run on non-internet connected compute nodes.\n        # Model needs to be loaded into cache from an internet-connected machine\n        # by running:\n        #   from transformers import BertTokenizerFast\n        #   BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n        try:\n            self.tokenizer = BertTokenizerFast.from_pretrained(\n                'bert-base-uncased', local_files_only=False\n            )\n        except OSError:\n            raise SystemExit(\n                \"* Couldn't load Bert tokenizer files. 
Try running scripts/preload_models.py from an internet-conected machine.\"\n            )\n        self.device = device\n        self.vq_interface = vq_interface\n        self.max_length = max_length\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(\n            text,\n            truncation=True,\n            max_length=self.max_length,\n            return_length=True,\n            return_overflowing_tokens=False,\n            padding='max_length',\n            return_tensors='pt',\n        )\n        tokens = batch_encoding['input_ids'].to(self.device)\n        return tokens\n\n    @torch.no_grad()\n    def encode(self, text):\n        tokens = self(text)\n        if not self.vq_interface:\n            return tokens\n        return None, None, [None, None, tokens]\n\n    def decode(self, text):\n        return text\n\n\nclass BERTEmbedder(AbstractEncoder):\n    \"\"\"Uses the BERT tokenizr model and add some transformer encoder layers\"\"\"\n\n    def __init__(\n        self,\n        n_embed,\n        n_layer,\n        vocab_size=30522,\n        max_seq_len=77,\n        device=choose_torch_device(),\n        use_tokenizer=True,\n        embedding_dropout=0.0,\n    ):\n        super().__init__()\n        self.use_tknz_fn = use_tokenizer\n        if self.use_tknz_fn:\n            self.tknz_fn = BERTTokenizer(\n                vq_interface=False, max_length=max_seq_len\n            )\n        self.device = device\n        self.transformer = TransformerWrapper(\n            num_tokens=vocab_size,\n            max_seq_len=max_seq_len,\n            attn_layers=Encoder(dim=n_embed, depth=n_layer),\n            emb_dropout=embedding_dropout,\n        )\n\n    def forward(self, text, embedding_manager=None):\n        if self.use_tknz_fn:\n            tokens = self.tknz_fn(text)  # .to(self.device)\n        else:\n            tokens = text\n        z = self.transformer(\n            tokens, return_embeddings=True, 
embedding_manager=embedding_manager\n        )\n        return z\n\n    def encode(self, text, **kwargs):\n        # output of length 77\n        return self(text, **kwargs)\n\n\nclass SpatialRescaler(nn.Module):\n    def __init__(\n        self,\n        n_stages=1,\n        method='bilinear',\n        multiplier=0.5,\n        in_channels=3,\n        out_channels=None,\n        bias=False,\n    ):\n        super().__init__()\n        self.n_stages = n_stages\n        assert self.n_stages >= 0\n        assert method in [\n            'nearest',\n            'linear',\n            'bilinear',\n            'trilinear',\n            'bicubic',\n            'area',\n        ]\n        self.multiplier = multiplier\n        self.interpolator = partial(\n            torch.nn.functional.interpolate, mode=method\n        )\n        self.remap_output = out_channels is not None\n        if self.remap_output:\n            print(\n                f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'\n            )\n            self.channel_mapper = nn.Conv2d(\n                in_channels, out_channels, 1, bias=bias\n            )\n\n    def forward(self, x):\n        for stage in range(self.n_stages):\n            x = self.interpolator(x, scale_factor=self.multiplier)\n\n        if self.remap_output:\n            x = self.channel_mapper(x)\n        return x\n\n    def encode(self, x):\n        return self(x)\n\n\nclass FrozenCLIPEmbedder(AbstractEncoder):\n    \"\"\"Uses the CLIP transformer encoder for text (from Hugging Face)\"\"\"\n\n    def __init__(\n        self,\n        version='openai/clip-vit-large-patch14',\n        device=choose_torch_device(),\n        max_length=77,\n    ):\n        super().__init__()\n        self.tokenizer = CLIPTokenizer.from_pretrained(\n            version, local_files_only=False\n        )\n        self.transformer = CLIPTextModel.from_pretrained(\n            version, local_files_only=False\n        )\n   
     self.device = device\n        self.max_length = max_length\n        self.freeze()\n\n        def embedding_forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            inputs_embeds=None,\n            embedding_manager=None,\n        ) -> torch.Tensor:\n\n            seq_length = (\n                input_ids.shape[-1]\n                if input_ids is not None\n                else inputs_embeds.shape[-2]\n            )\n\n            if position_ids is None:\n                position_ids = self.position_ids[:, :seq_length]\n\n            if inputs_embeds is None:\n                inputs_embeds = self.token_embedding(input_ids)\n\n            if embedding_manager is not None:\n                inputs_embeds = embedding_manager(input_ids, inputs_embeds)\n\n            position_embeddings = self.position_embedding(position_ids)\n            embeddings = inputs_embeds + position_embeddings\n\n            return embeddings\n\n        self.transformer.text_model.embeddings.forward = (\n            embedding_forward.__get__(self.transformer.text_model.embeddings)\n        )\n\n        def encoder_forward(\n            self,\n            inputs_embeds,\n            attention_mask=None,\n            causal_attention_mask=None,\n            output_attentions=None,\n            output_hidden_states=None,\n            return_dict=None,\n        ):\n            output_attentions = (\n                output_attentions\n                if output_attentions is not None\n                else self.config.output_attentions\n            )\n            output_hidden_states = (\n                output_hidden_states\n                if output_hidden_states is not None\n                else self.config.output_hidden_states\n            )\n            return_dict = (\n                return_dict\n                if return_dict is not None\n                else self.config.use_return_dict\n            )\n\n            encoder_states = () if 
output_hidden_states else None\n            all_attentions = () if output_attentions else None\n\n            hidden_states = inputs_embeds\n            for idx, encoder_layer in enumerate(self.layers):\n                if output_hidden_states:\n                    encoder_states = encoder_states + (hidden_states,)\n\n                layer_outputs = encoder_layer(\n                    hidden_states,\n                    attention_mask,\n                    causal_attention_mask,\n                    output_attentions=output_attentions,\n                )\n\n                hidden_states = layer_outputs[0]\n\n                if output_attentions:\n                    all_attentions = all_attentions + (layer_outputs[1],)\n\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n\n            return hidden_states\n\n        self.transformer.text_model.encoder.forward = encoder_forward.__get__(\n            self.transformer.text_model.encoder\n        )\n\n        def text_encoder_forward(\n            self,\n            input_ids=None,\n            attention_mask=None,\n            position_ids=None,\n            output_attentions=None,\n            output_hidden_states=None,\n            return_dict=None,\n            embedding_manager=None,\n        ):\n            output_attentions = (\n                output_attentions\n                if output_attentions is not None\n                else self.config.output_attentions\n            )\n            output_hidden_states = (\n                output_hidden_states\n                if output_hidden_states is not None\n                else self.config.output_hidden_states\n            )\n            return_dict = (\n                return_dict\n                if return_dict is not None\n                else self.config.use_return_dict\n            )\n\n            if input_ids is None:\n                raise ValueError('You have to specify either input_ids')\n\n          
  input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n\n            hidden_states = self.embeddings(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                embedding_manager=embedding_manager,\n            )\n\n            bsz, seq_len = input_shape\n            # CLIP's text model uses causal mask, prepare it here.\n            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n            causal_attention_mask = _build_causal_attention_mask(\n                bsz, seq_len, hidden_states.dtype\n            ).to(hidden_states.device)\n\n            # expand attention_mask\n            if attention_mask is not None:\n                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n                attention_mask = _expand_mask(\n                    attention_mask, hidden_states.dtype\n                )\n\n            last_hidden_state = self.encoder(\n                inputs_embeds=hidden_states,\n                attention_mask=attention_mask,\n                causal_attention_mask=causal_attention_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n\n            last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n            return last_hidden_state\n\n        self.transformer.text_model.forward = text_encoder_forward.__get__(\n            self.transformer.text_model\n        )\n\n        def transformer_forward(\n            self,\n            input_ids=None,\n            attention_mask=None,\n            position_ids=None,\n            output_attentions=None,\n            output_hidden_states=None,\n            return_dict=None,\n            embedding_manager=None,\n        ):\n            return self.text_model(\n                input_ids=input_ids,\n                
attention_mask=attention_mask,\n                position_ids=position_ids,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                embedding_manager=embedding_manager,\n            )\n\n        self.transformer.forward = transformer_forward.__get__(\n            self.transformer\n        )\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text, **kwargs):\n        batch_encoding = self.tokenizer(\n            text,\n            truncation=True,\n            max_length=self.max_length,\n            return_length=True,\n            return_overflowing_tokens=False,\n            padding='max_length',\n            return_tensors='pt',\n        )\n        tokens = batch_encoding['input_ids'].to(self.device)\n        z = self.transformer(input_ids=tokens, **kwargs)\n\n        return z\n\n    def encode(self, text, **kwargs):\n        return self(text, **kwargs)\n\n\nclass FrozenCLIPTextEmbedder(nn.Module):\n    \"\"\"\n    Uses the CLIP transformer encoder for text.\n    \"\"\"\n\n    def __init__(\n        self,\n        version='ViT-L/14',\n        device=choose_torch_device(),\n        max_length=77,\n        n_repeat=1,\n        normalize=True,\n    ):\n        super().__init__()\n        self.model, _ = clip.load(version, jit=False, device=device)\n        self.device = device\n        self.max_length = max_length\n        self.n_repeat = n_repeat\n        self.normalize = normalize\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        tokens = clip.tokenize(text).to(self.device)\n        z = self.model.encode_text(tokens)\n        if self.normalize:\n            z = z / 
torch.linalg.norm(z, dim=1, keepdim=True)\n        return z\n\n    def encode(self, text):\n        z = self(text)\n        if z.ndim == 2:\n            z = z[:, None, :]\n        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)\n        return z\n\n\nclass FrozenClipImageEmbedder(nn.Module):\n    \"\"\"\n    Uses the CLIP image encoder.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        jit=False,\n        device=choose_torch_device(),\n        antialias=False,\n    ):\n        super().__init__()\n        self.model, _ = clip.load(name=model, device=device, jit=jit)\n\n        self.antialias = antialias\n\n        self.register_buffer(\n            'mean',\n            torch.Tensor([0.48145466, 0.4578275, 0.40821073]),\n            persistent=False,\n        )\n        self.register_buffer(\n            'std',\n            torch.Tensor([0.26862954, 0.26130258, 0.27577711]),\n            persistent=False,\n        )\n\n    def preprocess(self, x):\n        # normalize to [0,1]\n        x = kornia.geometry.resize(\n            x,\n            (224, 224),\n            interpolation='bicubic',\n            align_corners=True,\n            antialias=self.antialias,\n        )\n        x = (x + 1.0) / 2.0\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def forward(self, x):\n        # x is assumed to be in range [-1,1]\n        return self.model.encode_image(self.preprocess(x))\n\n\nif __name__ == '__main__':\n    from src.stablediffusion.ldm.util import count_params\n\n    model = FrozenCLIPEmbedder()\n    count_params(model, verbose=True)\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/image_degradation/__init__.py",
    "content": "from src.stablediffusion.ldm.modules.image_degradation.bsrgan import (\n    degradation_bsrgan_variant as degradation_fn_bsr,\n)\nfrom src.stablediffusion.ldm.modules.image_degradation.bsrgan_light import (\n    degradation_bsrgan_variant as degradation_fn_bsr_light,\n)\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/image_degradation/bsrgan.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# --------------------------------------------\n#\n# Kai Zhang (cskaizhang@gmail.com)\n# https://github.com/cszn\n# From 2019/03--2021/08\n# --------------------------------------------\n\"\"\"\n\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom scipy import ndimage\nimport scipy\nimport scipy.stats as ss\nfrom scipy.interpolate import interp2d\nfrom scipy.linalg import orth\nimport albumentations\n\nimport ldm.modules.image_degradation.utils_image as util\n\n\ndef modcrop_np(img, sf):\n    \"\"\"\n    Args:\n        img: numpy image, WxH or WxHxC\n        sf: scale factor\n    Return:\n        cropped image\n    \"\"\"\n    w, h = img.shape[:2]\n    im = np.copy(img)\n    return im[: w - w % sf, : h - h % sf, ...]\n\n\n\"\"\"\n# --------------------------------------------\n# anisotropic Gaussian kernels\n# --------------------------------------------\n\"\"\"\n\n\ndef analytic_kernel(k):\n    \"\"\"Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)\"\"\"\n    k_size = k.shape[0]\n    # Calculate the big kernels size\n    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))\n    # Loop over the small kernel to fill the big one\n    for r in range(k_size):\n        for c in range(k_size):\n            big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (\n                k[r, c] * k\n            )\n    # Crop the edges of the big kernel to ignore very small values and increase run time of SR\n    crop = k_size // 2\n    cropped_big_k = big_k[crop:-crop, crop:-crop]\n    # Normalize to 1\n    return cropped_big_k / cropped_big_k.sum()\n\n\ndef anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):\n    \"\"\"generate an anisotropic Gaussian kernel\n    Args:\n        ksize : e.g., 15, kernel size\n        theta : [0,  pi], rotation angle range\n        l1    : [0.1,50], 
scaling of eigenvalues\n        l2    : [0.1,l1], scaling of eigenvalues\n        If l1 = l2, will get an isotropic Gaussian kernel.\n    Returns:\n        k     : kernel\n    \"\"\"\n\n    v = np.dot(\n        np.array(\n            [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]\n        ),\n        np.array([1.0, 0.0]),\n    )\n    V = np.array([[v[0], v[1]], [v[1], -v[0]]])\n    D = np.array([[l1, 0], [0, l2]])\n    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))\n    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)\n\n    return k\n\n\ndef gm_blur_kernel(mean, cov, size=15):\n    center = size / 2.0 + 0.5\n    k = np.zeros([size, size])\n    for y in range(size):\n        for x in range(size):\n            cy = y - center + 1\n            cx = x - center + 1\n            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)\n\n    k = k / np.sum(k)\n    return k\n\n\ndef shift_pixel(x, sf, upper_left=True):\n    \"\"\"shift pixel for super-resolution with different scale factors\n    Args:\n        x: WxHxC or WxH\n        sf: scale factor\n        upper_left: shift direction\n    \"\"\"\n    h, w = x.shape[:2]\n    shift = (sf - 1) * 0.5\n    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)\n    if upper_left:\n        x1 = xv + shift\n        y1 = yv + shift\n    else:\n        x1 = xv - shift\n        y1 = yv - shift\n\n    x1 = np.clip(x1, 0, w - 1)\n    y1 = np.clip(y1, 0, h - 1)\n\n    if x.ndim == 2:\n        x = interp2d(xv, yv, x)(x1, y1)\n    if x.ndim == 3:\n        for i in range(x.shape[-1]):\n            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)\n\n    return x\n\n\ndef blur(x, k):\n    \"\"\"\n    x: image, NxcxHxW\n    k: kernel, Nx1xhxw\n    \"\"\"\n    n, c = x.shape[:2]\n    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2\n    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')\n    k = k.repeat(1, c, 1, 1)\n    k = k.view(-1, 1, k.shape[2], k.shape[3])\n    x = 
x.view(1, -1, x.shape[2], x.shape[3])\n    x = torch.nn.functional.conv2d(\n        x, k, bias=None, stride=1, padding=0, groups=n * c\n    )\n    x = x.view(n, c, x.shape[2], x.shape[3])\n\n    return x\n\n\ndef gen_kernel(\n    k_size=np.array([15, 15]),\n    scale_factor=np.array([4, 4]),\n    min_var=0.6,\n    max_var=10.0,\n    noise_level=0,\n):\n    \"\"\" \"\n    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator\n    # Kai Zhang\n    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var\n    # max_var = 2.5 * sf\n    \"\"\"\n    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix\n    lambda_1 = min_var + np.random.rand() * (max_var - min_var)\n    lambda_2 = min_var + np.random.rand() * (max_var - min_var)\n    theta = np.random.rand() * np.pi  # random theta\n    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2\n\n    # Set COV matrix using Lambdas and Theta\n    LAMBDA = np.diag([lambda_1, lambda_2])\n    Q = np.array(\n        [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]\n    )\n    SIGMA = Q @ LAMBDA @ Q.T\n    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]\n\n    # Set expectation position (shifting kernel for aligned image)\n    MU = k_size // 2 - 0.5 * (\n        scale_factor - 1\n    )  # - 0.5 * (scale_factor - k_size % 2)\n    MU = MU[None, None, :, None]\n\n    # Create meshgrid for Gaussian\n    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))\n    Z = np.stack([X, Y], 2)[:, :, :, None]\n\n    # Calcualte Gaussian for every pixel of the kernel\n    ZZ = Z - MU\n    ZZ_t = ZZ.transpose(0, 1, 3, 2)\n    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)\n\n    # shift the kernel so it will be centered\n    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)\n\n    # Normalize the kernel and return\n    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)\n    
kernel = raw_kernel / np.sum(raw_kernel)\n    return kernel\n\n\ndef fspecial_gaussian(hsize, sigma):\n    hsize = [hsize, hsize]\n    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]\n    std = sigma\n    [x, y] = np.meshgrid(\n        np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)\n    )\n    arg = -(x * x + y * y) / (2 * std * std)\n    h = np.exp(arg)\n    h[h < scipy.finfo(float).eps * h.max()] = 0\n    sumh = h.sum()\n    if sumh != 0:\n        h = h / sumh\n    return h\n\n\ndef fspecial_laplacian(alpha):\n    alpha = max([0, min([alpha, 1])])\n    h1 = alpha / (alpha + 1)\n    h2 = (1 - alpha) / (alpha + 1)\n    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]\n    h = np.array(h)\n    return h\n\n\ndef fspecial(filter_type, *args, **kwargs):\n    \"\"\"\n    python code from:\n    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py\n    \"\"\"\n    if filter_type == 'gaussian':\n        return fspecial_gaussian(*args, **kwargs)\n    if filter_type == 'laplacian':\n        return fspecial_laplacian(*args, **kwargs)\n\n\n\"\"\"\n# --------------------------------------------\n# degradation models\n# --------------------------------------------\n\"\"\"\n\n\ndef bicubic_degradation(x, sf=3):\n    \"\"\"\n    Args:\n        x: HxWxC image, [0, 1]\n        sf: down-scale factor\n    Return:\n        bicubicly downsampled LR image\n    \"\"\"\n    x = util.imresize_np(x, scale=1 / sf)\n    return x\n\n\ndef srmd_degradation(x, k, sf=3):\n    \"\"\"blur + bicubic downsampling\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2018learning,\n          title={Learning a single convolutional super-resolution network for multiple degradations},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n   
       booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={3262--3271},\n          year={2018}\n        }\n    \"\"\"\n    x = ndimage.filters.convolve(\n        x, np.expand_dims(k, axis=2), mode='wrap'\n    )  # 'nearest' | 'mirror'\n    x = bicubic_degradation(x, sf=sf)\n    return x\n\n\ndef dpsr_degradation(x, k, sf=3):\n    \"\"\"bicubic downsampling + blur\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2019deep,\n          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={1671--1681},\n          year={2019}\n        }\n    \"\"\"\n    x = bicubic_degradation(x, sf=sf)\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    return x\n\n\ndef classical_degradation(x, k, sf=3):\n    \"\"\"blur + downsampling\n    Args:\n        x: HxWxC image, [0, 1]/[0, 255]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    \"\"\"\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))\n    st = 0\n    return x[st::sf, st::sf, ...]\n\n\ndef add_sharpening(img, weight=0.5, radius=50, threshold=10):\n    \"\"\"USM sharpening. borrowed from real-ESRGAN\n    Input image: I; Blurry image: B.\n    1. K = I + weight * (I - B)\n    2. Mask = 1 if abs(I - B) > threshold, else: 0\n    3. Blur mask:\n    4. Out = Mask * K + (1 - Mask) * I\n    Args:\n        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].\n        weight (float): Sharp weight. Default: 1.\n        radius (float): Kernel size of Gaussian blur. 
Default: 50.\n        threshold (int):\n    \"\"\"\n    if radius % 2 == 0:\n        radius += 1\n    blur = cv2.GaussianBlur(img, (radius, radius), 0)\n    residual = img - blur\n    mask = np.abs(residual) * 255 > threshold\n    mask = mask.astype('float32')\n    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)\n\n    K = img + weight * residual\n    K = np.clip(K, 0, 1)\n    return soft_mask * K + (1 - soft_mask) * img\n\n\ndef add_blur(img, sf=4):\n    wd2 = 4.0 + sf\n    wd = 2.0 + 0.2 * sf\n    if random.random() < 0.5:\n        l1 = wd2 * random.random()\n        l2 = wd2 * random.random()\n        k = anisotropic_Gaussian(\n            ksize=2 * random.randint(2, 11) + 3,\n            theta=random.random() * np.pi,\n            l1=l1,\n            l2=l2,\n        )\n    else:\n        k = fspecial(\n            'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()\n        )\n    img = ndimage.filters.convolve(\n        img, np.expand_dims(k, axis=2), mode='mirror'\n    )\n\n    return img\n\n\ndef add_resize(img, sf=4):\n    rnum = np.random.rand()\n    if rnum > 0.8:  # up\n        sf1 = random.uniform(1, 2)\n    elif rnum < 0.7:  # down\n        sf1 = random.uniform(0.5 / sf, 1)\n    else:\n        sf1 = 1.0\n    img = cv2.resize(\n        img,\n        (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),\n        interpolation=random.choice([1, 2, 3]),\n    )\n    img = np.clip(img, 0.0, 1.0)\n\n    return img\n\n\n# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n#     noise_level = random.randint(noise_level1, noise_level2)\n#     rnum = np.random.rand()\n#     if rnum > 0.6:  # add color Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n#     elif rnum < 0.4:  # add grayscale Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n#     else:  # add  noise\n#         L = noise_level2 / 255.\n#         D = 
np.diag(np.random.rand(3))\n#         U = orth(np.random.rand(3, 3))\n#         conv = np.dot(np.dot(np.transpose(U), D), U)\n#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n#     img = np.clip(img, 0.0, 1.0)\n#     return img\n\n\ndef add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    rnum = np.random.rand()\n    if rnum > 0.6:  # add color Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(\n            np.float32\n        )\n    elif rnum < 0.4:  # add grayscale Gaussian noise\n        img = img + np.random.normal(\n            0, noise_level / 255.0, (*img.shape[:2], 1)\n        ).astype(np.float32)\n    else:  # add  noise\n        L = noise_level2 / 255.0\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img = img + np.random.multivariate_normal(\n            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]\n        ).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_speckle_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    img = np.clip(img, 0.0, 1.0)\n    rnum = random.random()\n    if rnum > 0.6:\n        img += img * np.random.normal(\n            0, noise_level / 255.0, img.shape\n        ).astype(np.float32)\n    elif rnum < 0.4:\n        img += img * np.random.normal(\n            0, noise_level / 255.0, (*img.shape[:2], 1)\n        ).astype(np.float32)\n    else:\n        L = noise_level2 / 255.0\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img += img * np.random.multivariate_normal(\n            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]\n        ).astype(np.float32)\n    img = 
np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_Poisson_noise(img):\n    img = np.clip((img * 255.0).round(), 0, 255) / 255.0\n    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]\n    if random.random() < 0.5:\n        img = np.random.poisson(img * vals).astype(np.float32) / vals\n    else:\n        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])\n        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0\n        noise_gray = (\n            np.random.poisson(img_gray * vals).astype(np.float32) / vals\n            - img_gray\n        )\n        img += noise_gray[:, :, np.newaxis]\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_JPEG_noise(img):\n    quality_factor = random.randint(30, 95)\n    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)\n    result, encimg = cv2.imencode(\n        '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]\n    )\n    img = cv2.imdecode(encimg, 1)\n    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef random_crop(lq, hq, sf=4, lq_patchsize=64):\n    h, w = lq.shape[:2]\n    rnd_h = random.randint(0, h - lq_patchsize)\n    rnd_w = random.randint(0, w - lq_patchsize)\n    lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]\n\n    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)\n    hq = hq[\n        rnd_h_H : rnd_h_H + lq_patchsize * sf,\n        rnd_w_H : rnd_w_H + lq_patchsize * sf,\n        :,\n    ]\n    return lq, hq\n\n\ndef degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: 
corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    hq = img.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            img = cv2.resize(\n                img,\n                (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),\n                interpolation=random.choice([1, 2, 3]),\n            )\n        else:\n            img = util.imresize_np(img, 1 / 2, True)\n        img = np.clip(img, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = (\n            shuffle_order[idx2],\n            shuffle_order[idx1],\n        )\n\n    for i in shuffle_order:\n\n        if i == 0:\n            img = add_blur(img, sf=sf)\n\n        elif i == 1:\n            img = add_blur(img, sf=sf)\n\n        elif i == 2:\n            a, b = img.shape[1], img.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                img = cv2.resize(\n                    img,\n                    (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),\n                    interpolation=random.choice([1, 2, 3]),\n                )\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = (\n                    k_shifted / k_shifted.sum()\n                )  # blur with shifted 
kernel\n                img = ndimage.filters.convolve(\n                    img, np.expand_dims(k_shifted, axis=2), mode='mirror'\n                )\n                img = img[0::sf, 0::sf, ...]  # nearest downsampling\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            img = cv2.resize(\n                img,\n                (int(1 / sf * a), int(1 / sf * b)),\n                interpolation=random.choice([1, 2, 3]),\n            )\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                img = add_JPEG_noise(img)\n\n        elif i == 6:\n            # add processed camera sensor noise\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)\n\n    return img, hq\n\n\n# todo no isp_model?\ndef degradation_bsrgan_variant(image, sf=4, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    image = util.uint2single(image)\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = image.shape[:2]\n    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  
# mod crop\n    h, w = image.shape[:2]\n\n    hq = image.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            image = cv2.resize(\n                image,\n                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),\n                interpolation=random.choice([1, 2, 3]),\n            )\n        else:\n            image = util.imresize_np(image, 1 / 2, True)\n        image = np.clip(image, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = (\n            shuffle_order[idx2],\n            shuffle_order[idx1],\n        )\n\n    for i in shuffle_order:\n\n        if i == 0:\n            image = add_blur(image, sf=sf)\n\n        elif i == 1:\n            image = add_blur(image, sf=sf)\n\n        elif i == 2:\n            a, b = image.shape[1], image.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                image = cv2.resize(\n                    image,\n                    (\n                        int(1 / sf1 * image.shape[1]),\n                        int(1 / sf1 * image.shape[0]),\n                    ),\n                    interpolation=random.choice([1, 2, 3]),\n                )\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = (\n                    k_shifted / k_shifted.sum()\n                )  # blur with shifted kernel\n                image = ndimage.filters.convolve(\n                    image, np.expand_dims(k_shifted, axis=2), mode='mirror'\n                )\n                image = image[0::sf, 0::sf, ...]  
# nearest downsampling\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            image = cv2.resize(\n                image,\n                (int(1 / sf * a), int(1 / sf * b)),\n                interpolation=random.choice([1, 2, 3]),\n            )\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                image = add_JPEG_noise(image)\n\n        # elif i == 6:\n        #     # add processed camera sensor noise\n        #     if random.random() < isp_prob and isp_model is not None:\n        #         with torch.no_grad():\n        #             img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    image = add_JPEG_noise(image)\n    image = util.single2uint(image)\n    example = {'image': image}\n    return example\n\n\n# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...\ndef degradation_bsrgan_plus(\n    img,\n    sf=4,\n    shuffle_prob=0.5,\n    use_sharp=True,\n    lq_patchsize=64,\n    isp_model=None,\n):\n    \"\"\"\n    This is an extended degradation model by combining\n    the degradation models of BSRGAN and Real-ESRGAN\n    ----------\n    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    use_shuffle: the degradation shuffle\n    use_sharp: sharpening the img\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  
# mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    if use_sharp:\n        img = add_sharpening(img)\n    hq = img.copy()\n\n    if random.random() < shuffle_prob:\n        shuffle_order = random.sample(range(13), 13)\n    else:\n        shuffle_order = list(range(13))\n        # local shuffle for noise, JPEG is always the last one\n        shuffle_order[2:6] = random.sample(\n            shuffle_order[2:6], len(range(2, 6))\n        )\n        shuffle_order[9:13] = random.sample(\n            shuffle_order[9:13], len(range(9, 13))\n        )\n\n    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1\n\n    for i in shuffle_order:\n        if i == 0:\n            img = add_blur(img, sf=sf)\n        elif i == 1:\n            img = add_resize(img, sf=sf)\n        elif i == 2:\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n        elif i == 3:\n            if random.random() < poisson_prob:\n                img = add_Poisson_noise(img)\n        elif i == 4:\n            if random.random() < speckle_prob:\n                img = add_speckle_noise(img)\n        elif i == 5:\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n        elif i == 6:\n            img = add_JPEG_noise(img)\n        elif i == 7:\n            img = add_blur(img, sf=sf)\n        elif i == 8:\n            img = add_resize(img, sf=sf)\n        elif i == 9:\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n        elif i == 10:\n            if random.random() < poisson_prob:\n                img = add_Poisson_noise(img)\n        elif i == 11:\n            if random.random() < speckle_prob:\n                img = add_speckle_noise(img)\n        elif i == 12:\n            if random.random() < isp_prob and 
isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n        else:\n            print('check the shuffle!')\n\n    # resize to desired size\n    img = cv2.resize(\n        img,\n        (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),\n        interpolation=random.choice([1, 2, 3]),\n    )\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf, lq_patchsize)\n\n    return img, hq\n\n\nif __name__ == '__main__':\n    print('hey')\n    img = util.imread_uint('utils/test.png', 3)\n    print(img)\n    img = util.uint2single(img)\n    print(img)\n    img = img[:448, :448]\n    h = img.shape[0] // 4\n    print('resizing to', h)\n    sf = 4\n    deg_fn = partial(degradation_bsrgan_variant, sf=sf)\n    for i in range(20):\n        print(i)\n        img_lq = deg_fn(img)\n        print(img_lq)\n        img_lq_bicubic = albumentations.SmallestMaxSize(\n            max_size=h, interpolation=cv2.INTER_CUBIC\n        )(image=img)['image']\n        print(img_lq.shape)\n        print('bicubic', img_lq_bicubic.shape)\n        print(img_hq.shape)\n        lq_nearest = cv2.resize(\n            util.single2uint(img_lq),\n            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n            interpolation=0,\n        )\n        lq_bicubic_nearest = cv2.resize(\n            util.single2uint(img_lq_bicubic),\n            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n            interpolation=0,\n        )\n        img_concat = np.concatenate(\n            [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1\n        )\n        util.imsave(img_concat, str(i) + '.png')\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/image_degradation/bsrgan_light.py",
    "content": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom scipy import ndimage\nimport scipy\nimport scipy.stats as ss\nfrom scipy.interpolate import interp2d\nfrom scipy.linalg import orth\nimport albumentations\n\nimport ldm.modules.image_degradation.utils_image as util\n\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# --------------------------------------------\n#\n# Kai Zhang (cskaizhang@gmail.com)\n# https://github.com/cszn\n# From 2019/03--2021/08\n# --------------------------------------------\n\"\"\"\n\n\ndef modcrop_np(img, sf):\n    \"\"\"\n    Args:\n        img: numpy image, WxH or WxHxC\n        sf: scale factor\n    Return:\n        cropped image\n    \"\"\"\n    w, h = img.shape[:2]\n    im = np.copy(img)\n    return im[: w - w % sf, : h - h % sf, ...]\n\n\n\"\"\"\n# --------------------------------------------\n# anisotropic Gaussian kernels\n# --------------------------------------------\n\"\"\"\n\n\ndef analytic_kernel(k):\n    \"\"\"Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)\"\"\"\n    k_size = k.shape[0]\n    # Calculate the big kernels size\n    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))\n    # Loop over the small kernel to fill the big one\n    for r in range(k_size):\n        for c in range(k_size):\n            big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (\n                k[r, c] * k\n            )\n    # Crop the edges of the big kernel to ignore very small values and increase run time of SR\n    crop = k_size // 2\n    cropped_big_k = big_k[crop:-crop, crop:-crop]\n    # Normalize to 1\n    return cropped_big_k / cropped_big_k.sum()\n\n\ndef anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):\n    \"\"\"generate an anisotropic Gaussian kernel\n    Args:\n        ksize : e.g., 15, kernel size\n        theta : [0,  pi], rotation angle range\n        l1    : [0.1,50], 
scaling of eigenvalues\n        l2    : [0.1,l1], scaling of eigenvalues\n        If l1 = l2, will get an isotropic Gaussian kernel.\n    Returns:\n        k     : kernel\n    \"\"\"\n\n    v = np.dot(\n        np.array(\n            [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]\n        ),\n        np.array([1.0, 0.0]),\n    )\n    V = np.array([[v[0], v[1]], [v[1], -v[0]]])\n    D = np.array([[l1, 0], [0, l2]])\n    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))\n    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)\n\n    return k\n\n\ndef gm_blur_kernel(mean, cov, size=15):\n    center = size / 2.0 + 0.5\n    k = np.zeros([size, size])\n    for y in range(size):\n        for x in range(size):\n            cy = y - center + 1\n            cx = x - center + 1\n            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)\n\n    k = k / np.sum(k)\n    return k\n\n\ndef shift_pixel(x, sf, upper_left=True):\n    \"\"\"shift pixel for super-resolution with different scale factors\n    Args:\n        x: WxHxC or WxH\n        sf: scale factor\n        upper_left: shift direction\n    \"\"\"\n    h, w = x.shape[:2]\n    shift = (sf - 1) * 0.5\n    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)\n    if upper_left:\n        x1 = xv + shift\n        y1 = yv + shift\n    else:\n        x1 = xv - shift\n        y1 = yv - shift\n\n    x1 = np.clip(x1, 0, w - 1)\n    y1 = np.clip(y1, 0, h - 1)\n\n    if x.ndim == 2:\n        x = interp2d(xv, yv, x)(x1, y1)\n    if x.ndim == 3:\n        for i in range(x.shape[-1]):\n            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)\n\n    return x\n\n\ndef blur(x, k):\n    \"\"\"\n    x: image, NxcxHxW\n    k: kernel, Nx1xhxw\n    \"\"\"\n    n, c = x.shape[:2]\n    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2\n    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')\n    k = k.repeat(1, c, 1, 1)\n    k = k.view(-1, 1, k.shape[2], k.shape[3])\n    x = 
x.view(1, -1, x.shape[2], x.shape[3])\n    x = torch.nn.functional.conv2d(\n        x, k, bias=None, stride=1, padding=0, groups=n * c\n    )\n    x = x.view(n, c, x.shape[2], x.shape[3])\n\n    return x\n\n\ndef gen_kernel(\n    k_size=np.array([15, 15]),\n    scale_factor=np.array([4, 4]),\n    min_var=0.6,\n    max_var=10.0,\n    noise_level=0,\n):\n    \"\"\" \"\n    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator\n    # Kai Zhang\n    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var\n    # max_var = 2.5 * sf\n    \"\"\"\n    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix\n    lambda_1 = min_var + np.random.rand() * (max_var - min_var)\n    lambda_2 = min_var + np.random.rand() * (max_var - min_var)\n    theta = np.random.rand() * np.pi  # random theta\n    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2\n\n    # Set COV matrix using Lambdas and Theta\n    LAMBDA = np.diag([lambda_1, lambda_2])\n    Q = np.array(\n        [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]\n    )\n    SIGMA = Q @ LAMBDA @ Q.T\n    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]\n\n    # Set expectation position (shifting kernel for aligned image)\n    MU = k_size // 2 - 0.5 * (\n        scale_factor - 1\n    )  # - 0.5 * (scale_factor - k_size % 2)\n    MU = MU[None, None, :, None]\n\n    # Create meshgrid for Gaussian\n    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))\n    Z = np.stack([X, Y], 2)[:, :, :, None]\n\n    # Calcualte Gaussian for every pixel of the kernel\n    ZZ = Z - MU\n    ZZ_t = ZZ.transpose(0, 1, 3, 2)\n    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)\n\n    # shift the kernel so it will be centered\n    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)\n\n    # Normalize the kernel and return\n    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)\n    
kernel = raw_kernel / np.sum(raw_kernel)\n    return kernel\n\n\ndef fspecial_gaussian(hsize, sigma):\n    hsize = [hsize, hsize]\n    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]\n    std = sigma\n    [x, y] = np.meshgrid(\n        np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)\n    )\n    arg = -(x * x + y * y) / (2 * std * std)\n    h = np.exp(arg)\n    h[h < scipy.finfo(float).eps * h.max()] = 0\n    sumh = h.sum()\n    if sumh != 0:\n        h = h / sumh\n    return h\n\n\ndef fspecial_laplacian(alpha):\n    alpha = max([0, min([alpha, 1])])\n    h1 = alpha / (alpha + 1)\n    h2 = (1 - alpha) / (alpha + 1)\n    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]\n    h = np.array(h)\n    return h\n\n\ndef fspecial(filter_type, *args, **kwargs):\n    \"\"\"\n    python code from:\n    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py\n    \"\"\"\n    if filter_type == 'gaussian':\n        return fspecial_gaussian(*args, **kwargs)\n    if filter_type == 'laplacian':\n        return fspecial_laplacian(*args, **kwargs)\n\n\n\"\"\"\n# --------------------------------------------\n# degradation models\n# --------------------------------------------\n\"\"\"\n\n\ndef bicubic_degradation(x, sf=3):\n    \"\"\"\n    Args:\n        x: HxWxC image, [0, 1]\n        sf: down-scale factor\n    Return:\n        bicubicly downsampled LR image\n    \"\"\"\n    x = util.imresize_np(x, scale=1 / sf)\n    return x\n\n\ndef srmd_degradation(x, k, sf=3):\n    \"\"\"blur + bicubic downsampling\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2018learning,\n          title={Learning a single convolutional super-resolution network for multiple degradations},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n   
       booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={3262--3271},\n          year={2018}\n        }\n    \"\"\"\n    x = ndimage.filters.convolve(\n        x, np.expand_dims(k, axis=2), mode='wrap'\n    )  # 'nearest' | 'mirror'\n    x = bicubic_degradation(x, sf=sf)\n    return x\n\n\ndef dpsr_degradation(x, k, sf=3):\n    \"\"\"bicubic downsampling + blur\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2019deep,\n          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={1671--1681},\n          year={2019}\n        }\n    \"\"\"\n    x = bicubic_degradation(x, sf=sf)\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    return x\n\n\ndef classical_degradation(x, k, sf=3):\n    \"\"\"blur + downsampling\n    Args:\n        x: HxWxC image, [0, 1]/[0, 255]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    \"\"\"\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))\n    st = 0\n    return x[st::sf, st::sf, ...]\n\n\ndef add_sharpening(img, weight=0.5, radius=50, threshold=10):\n    \"\"\"USM sharpening. borrowed from real-ESRGAN\n    Input image: I; Blurry image: B.\n    1. K = I + weight * (I - B)\n    2. Mask = 1 if abs(I - B) > threshold, else: 0\n    3. Blur mask:\n    4. Out = Mask * K + (1 - Mask) * I\n    Args:\n        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].\n        weight (float): Sharp weight. Default: 1.\n        radius (float): Kernel size of Gaussian blur. 
Default: 50.\n        threshold (int):\n    \"\"\"\n    if radius % 2 == 0:\n        radius += 1\n    blur = cv2.GaussianBlur(img, (radius, radius), 0)\n    residual = img - blur\n    mask = np.abs(residual) * 255 > threshold\n    mask = mask.astype('float32')\n    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)\n\n    K = img + weight * residual\n    K = np.clip(K, 0, 1)\n    return soft_mask * K + (1 - soft_mask) * img\n\n\ndef add_blur(img, sf=4):\n    wd2 = 4.0 + sf\n    wd = 2.0 + 0.2 * sf\n\n    wd2 = wd2 / 4\n    wd = wd / 4\n\n    if random.random() < 0.5:\n        l1 = wd2 * random.random()\n        l2 = wd2 * random.random()\n        k = anisotropic_Gaussian(\n            ksize=random.randint(2, 11) + 3,\n            theta=random.random() * np.pi,\n            l1=l1,\n            l2=l2,\n        )\n    else:\n        k = fspecial(\n            'gaussian', random.randint(2, 4) + 3, wd * random.random()\n        )\n    img = ndimage.filters.convolve(\n        img, np.expand_dims(k, axis=2), mode='mirror'\n    )\n\n    return img\n\n\ndef add_resize(img, sf=4):\n    rnum = np.random.rand()\n    if rnum > 0.8:  # up\n        sf1 = random.uniform(1, 2)\n    elif rnum < 0.7:  # down\n        sf1 = random.uniform(0.5 / sf, 1)\n    else:\n        sf1 = 1.0\n    img = cv2.resize(\n        img,\n        (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),\n        interpolation=random.choice([1, 2, 3]),\n    )\n    img = np.clip(img, 0.0, 1.0)\n\n    return img\n\n\n# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n#     noise_level = random.randint(noise_level1, noise_level2)\n#     rnum = np.random.rand()\n#     if rnum > 0.6:  # add color Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n#     elif rnum < 0.4:  # add grayscale Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n#     else:  # add  noise\n#         L = 
noise_level2 / 255.\n#         D = np.diag(np.random.rand(3))\n#         U = orth(np.random.rand(3, 3))\n#         conv = np.dot(np.dot(np.transpose(U), D), U)\n#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n#     img = np.clip(img, 0.0, 1.0)\n#     return img\n\n\ndef add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    rnum = np.random.rand()\n    if rnum > 0.6:  # add color Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(\n            np.float32\n        )\n    elif rnum < 0.4:  # add grayscale Gaussian noise\n        img = img + np.random.normal(\n            0, noise_level / 255.0, (*img.shape[:2], 1)\n        ).astype(np.float32)\n    else:  # add  noise\n        L = noise_level2 / 255.0\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img = img + np.random.multivariate_normal(\n            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]\n        ).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_speckle_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    img = np.clip(img, 0.0, 1.0)\n    rnum = random.random()\n    if rnum > 0.6:\n        img += img * np.random.normal(\n            0, noise_level / 255.0, img.shape\n        ).astype(np.float32)\n    elif rnum < 0.4:\n        img += img * np.random.normal(\n            0, noise_level / 255.0, (*img.shape[:2], 1)\n        ).astype(np.float32)\n    else:\n        L = noise_level2 / 255.0\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img += img * np.random.multivariate_normal(\n            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]\n        
).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_Poisson_noise(img):\n    img = np.clip((img * 255.0).round(), 0, 255) / 255.0\n    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]\n    if random.random() < 0.5:\n        img = np.random.poisson(img * vals).astype(np.float32) / vals\n    else:\n        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])\n        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0\n        noise_gray = (\n            np.random.poisson(img_gray * vals).astype(np.float32) / vals\n            - img_gray\n        )\n        img += noise_gray[:, :, np.newaxis]\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_JPEG_noise(img):\n    quality_factor = random.randint(80, 95)\n    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)\n    result, encimg = cv2.imencode(\n        '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]\n    )\n    img = cv2.imdecode(encimg, 1)\n    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef random_crop(lq, hq, sf=4, lq_patchsize=64):\n    h, w = lq.shape[:2]\n    rnd_h = random.randint(0, h - lq_patchsize)\n    rnd_w = random.randint(0, w - lq_patchsize)\n    lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]\n\n    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)\n    hq = hq[\n        rnd_h_H : rnd_h_H + lq_patchsize * sf,\n        rnd_w_H : rnd_w_H + lq_patchsize * sf,\n        :,\n    ]\n    return lq, hq\n\n\ndef degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: 
lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    hq = img.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            img = cv2.resize(\n                img,\n                (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),\n                interpolation=random.choice([1, 2, 3]),\n            )\n        else:\n            img = util.imresize_np(img, 1 / 2, True)\n        img = np.clip(img, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = (\n            shuffle_order[idx2],\n            shuffle_order[idx1],\n        )\n\n    for i in shuffle_order:\n\n        if i == 0:\n            img = add_blur(img, sf=sf)\n\n        elif i == 1:\n            img = add_blur(img, sf=sf)\n\n        elif i == 2:\n            a, b = img.shape[1], img.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                img = cv2.resize(\n                    img,\n                    (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),\n                    interpolation=random.choice([1, 2, 3]),\n                )\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = (\n                    k_shifted / 
k_shifted.sum()\n                )  # blur with shifted kernel\n                img = ndimage.filters.convolve(\n                    img, np.expand_dims(k_shifted, axis=2), mode='mirror'\n                )\n                img = img[0::sf, 0::sf, ...]  # nearest downsampling\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            img = cv2.resize(\n                img,\n                (int(1 / sf * a), int(1 / sf * b)),\n                interpolation=random.choice([1, 2, 3]),\n            )\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                img = add_JPEG_noise(img)\n\n        elif i == 6:\n            # add processed camera sensor noise\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)\n\n    return img, hq\n\n\n# todo no isp_model?\ndef degradation_bsrgan_variant(image, sf=4, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    image = util.uint2single(image)\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = image.shape[:2]\n    image = image.copy()[: w1 - w1 
% sf, : h1 - h1 % sf, ...]  # mod crop\n    h, w = image.shape[:2]\n\n    hq = image.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            image = cv2.resize(\n                image,\n                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),\n                interpolation=random.choice([1, 2, 3]),\n            )\n        else:\n            image = util.imresize_np(image, 1 / 2, True)\n        image = np.clip(image, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = (\n            shuffle_order[idx2],\n            shuffle_order[idx1],\n        )\n\n    for i in shuffle_order:\n\n        if i == 0:\n            image = add_blur(image, sf=sf)\n\n        # elif i == 1:\n        #     image = add_blur(image, sf=sf)\n\n        if i == 0:\n            pass\n\n        elif i == 2:\n            a, b = image.shape[1], image.shape[0]\n            # downsample2\n            if random.random() < 0.8:\n                sf1 = random.uniform(1, 2 * sf)\n                image = cv2.resize(\n                    image,\n                    (\n                        int(1 / sf1 * image.shape[1]),\n                        int(1 / sf1 * image.shape[0]),\n                    ),\n                    interpolation=random.choice([1, 2, 3]),\n                )\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = (\n                    k_shifted / k_shifted.sum()\n                )  # blur with shifted kernel\n                image = ndimage.filters.convolve(\n                    image, np.expand_dims(k_shifted, axis=2), mode='mirror'\n                )\n                image = image[0::sf, 0::sf, ...]  
# nearest downsampling\n\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            image = cv2.resize(\n                image,\n                (int(1 / sf * a), int(1 / sf * b)),\n                interpolation=random.choice([1, 2, 3]),\n            )\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                image = add_JPEG_noise(image)\n        #\n        # elif i == 6:\n        #     # add processed camera sensor noise\n        #     if random.random() < isp_prob and isp_model is not None:\n        #         with torch.no_grad():\n        #             img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    image = add_JPEG_noise(image)\n    image = util.single2uint(image)\n    example = {'image': image}\n    return example\n\n\nif __name__ == '__main__':\n    print('hey')\n    img = util.imread_uint('utils/test.png', 3)\n    img = img[:448, :448]\n    h = img.shape[0] // 4\n    print('resizing to', h)\n    sf = 4\n    deg_fn = partial(degradation_bsrgan_variant, sf=sf)\n    for i in range(20):\n        print(i)\n        img_hq = img\n        img_lq = deg_fn(img)['image']\n        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)\n        print(img_lq)\n        img_lq_bicubic = albumentations.SmallestMaxSize(\n            max_size=h, interpolation=cv2.INTER_CUBIC\n        )(image=img_hq)['image']\n        print(img_lq.shape)\n        print('bicubic', img_lq_bicubic.shape)\n        print(img_hq.shape)\n        lq_nearest = cv2.resize(\n            util.single2uint(img_lq),\n            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n            interpolation=0,\n        )\n        lq_bicubic_nearest = cv2.resize(\n 
           util.single2uint(img_lq_bicubic),\n            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n            interpolation=0,\n        )\n        img_concat = np.concatenate(\n            [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1\n        )\n        util.imsave(img_concat, str(i) + '.png')\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/image_degradation/utils_image.py",
    "content": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nfrom datetime import datetime\n\n# import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py\n\n\nos.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\n\n\n\"\"\"\n# --------------------------------------------\n# Kai Zhang (github: https://github.com/cszn)\n# 03/Mar/2019\n# --------------------------------------------\n# https://github.com/twhui/SRGAN-pyTorch\n# https://github.com/xinntao/BasicSR\n# --------------------------------------------\n\"\"\"\n\n\nIMG_EXTENSIONS = [\n    '.jpg',\n    '.JPG',\n    '.jpeg',\n    '.JPEG',\n    '.png',\n    '.PNG',\n    '.ppm',\n    '.PPM',\n    '.bmp',\n    '.BMP',\n    '.tif',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef get_timestamp():\n    return datetime.now().strftime('%y%m%d-%H%M%S')\n\n\ndef imshow(x, title=None, cbar=False, figsize=None):\n    plt.figure(figsize=figsize)\n    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')\n    if title:\n        plt.title(title)\n    if cbar:\n        plt.colorbar()\n    plt.show()\n\n\ndef surf(Z, cmap='rainbow', figsize=None):\n    plt.figure(figsize=figsize)\n    ax3 = plt.axes(projection='3d')\n\n    w, h = Z.shape[:2]\n    xx = np.arange(0, w, 1)\n    yy = np.arange(0, h, 1)\n    X, Y = np.meshgrid(xx, yy)\n    ax3.plot_surface(X, Y, Z, cmap=cmap)\n    # ax3.contour(X,Y,Z, zdim='z',offset=-2，cmap=cmap)\n    plt.show()\n\n\n\"\"\"\n# --------------------------------------------\n# get image pathes\n# --------------------------------------------\n\"\"\"\n\n\ndef get_image_paths(dataroot):\n    paths = None  # return None if dataroot is None\n    if dataroot is not None:\n        paths = sorted(_get_paths_from_images(dataroot))\n    return paths\n\n\ndef _get_paths_from_images(path):\n    assert os.path.isdir(path), 
'{:s} is not a valid directory'.format(path)\n    images = []\n    for dirpath, _, fnames in sorted(os.walk(path)):\n        for fname in sorted(fnames):\n            if is_image_file(fname):\n                img_path = os.path.join(dirpath, fname)\n                images.append(img_path)\n    assert images, '{:s} has no valid image file'.format(path)\n    return images\n\n\n\"\"\"\n# --------------------------------------------\n# split large images into small images \n# --------------------------------------------\n\"\"\"\n\n\ndef patches_from_image(img, p_size=512, p_overlap=64, p_max=800):\n    w, h = img.shape[:2]\n    patches = []\n    if w > p_max and h > p_max:\n        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))\n        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))\n        w1.append(w - p_size)\n        h1.append(h - p_size)\n        #        print(w1)\n        #        print(h1)\n        for i in w1:\n            for j in h1:\n                patches.append(img[i : i + p_size, j : j + p_size, :])\n    else:\n        patches.append(img)\n\n    return patches\n\n\ndef imssave(imgs, img_path):\n    \"\"\"\n    imgs: list, N images of size WxHxC\n    \"\"\"\n    img_name, ext = os.path.splitext(os.path.basename(img_path))\n\n    for i, img in enumerate(imgs):\n        if img.ndim == 3:\n            img = img[:, :, [2, 1, 0]]\n        new_path = os.path.join(\n            os.path.dirname(img_path),\n            img_name + str('_s{:04d}'.format(i)) + '.png',\n        )\n        cv2.imwrite(new_path, img)\n\n\ndef split_imageset(\n    original_dataroot,\n    taget_dataroot,\n    n_channels=3,\n    p_size=800,\n    p_overlap=96,\n    p_max=1000,\n):\n    \"\"\"\n    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),\n    and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)\n    will be splitted.\n    Args:\n        
original_dataroot:\n        taget_dataroot:\n        p_size: size of small images\n        p_overlap: patch size in training is a good choice\n        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.\n    \"\"\"\n    paths = get_image_paths(original_dataroot)\n    for img_path in paths:\n        # img_name, ext = os.path.splitext(os.path.basename(img_path))\n        img = imread_uint(img_path, n_channels=n_channels)\n        patches = patches_from_image(img, p_size, p_overlap, p_max)\n        imssave(\n            patches, os.path.join(taget_dataroot, os.path.basename(img_path))\n        )\n        # if original_dataroot == taget_dataroot:\n        # del img_path\n\n\n\"\"\"\n# --------------------------------------------\n# makedir\n# --------------------------------------------\n\"\"\"\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef mkdirs(paths):\n    if isinstance(paths, str):\n        mkdir(paths)\n    else:\n        for path in paths:\n            mkdir(path)\n\n\ndef mkdir_and_rename(path):\n    if os.path.exists(path):\n        new_name = path + '_archived_' + get_timestamp()\n        print('Path already exists. 
Rename it to [{:s}]'.format(new_name))\n        os.rename(path, new_name)\n    os.makedirs(path)\n\n\n\"\"\"\n# --------------------------------------------\n# read image from path\n# opencv is fast, but read BGR numpy image\n# --------------------------------------------\n\"\"\"\n\n\n# --------------------------------------------\n# get uint8 image of size HxWxn_channles (RGB)\n# --------------------------------------------\ndef imread_uint(path, n_channels=3):\n    #  input: path\n    # output: HxWx3(RGB or GGG), or HxWx1 (G)\n    if n_channels == 1:\n        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE\n        img = np.expand_dims(img, axis=2)  # HxWx1\n    elif n_channels == 3:\n        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G\n        if img.ndim == 2:\n            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG\n        else:\n            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB\n    return img\n\n\n# --------------------------------------------\n# matlab's imwrite\n# --------------------------------------------\ndef imsave(img, img_path):\n    img = np.squeeze(img)\n    if img.ndim == 3:\n        img = img[:, :, [2, 1, 0]]\n    cv2.imwrite(img_path, img)\n\n\ndef imwrite(img, img_path):\n    img = np.squeeze(img)\n    if img.ndim == 3:\n        img = img[:, :, [2, 1, 0]]\n    cv2.imwrite(img_path, img)\n\n\n# --------------------------------------------\n# get single image of size HxWxn_channles (BGR)\n# --------------------------------------------\ndef read_img(path):\n    # read image by cv2\n    # return: Numpy float32, HWC, BGR, [0,1]\n    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE\n    img = img.astype(np.float32) / 255.0\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    # some images have 4 channels\n    if img.shape[2] > 3:\n        img = img[:, :, :3]\n    return img\n\n\n\"\"\"\n# --------------------------------------------\n# image format conversion\n# 
--------------------------------------------\n# numpy(single) <--->  numpy(unit)\n# numpy(single) <--->  tensor\n# numpy(unit)   <--->  tensor\n# --------------------------------------------\n\"\"\"\n\n\n# --------------------------------------------\n# numpy(single) [0, 1] <--->  numpy(unit)\n# --------------------------------------------\n\n\ndef uint2single(img):\n\n    return np.float32(img / 255.0)\n\n\ndef single2uint(img):\n\n    return np.uint8((img.clip(0, 1) * 255.0).round())\n\n\ndef uint162single(img):\n\n    return np.float32(img / 65535.0)\n\n\ndef single2uint16(img):\n\n    return np.uint16((img.clip(0, 1) * 65535.0).round())\n\n\n# --------------------------------------------\n# numpy(unit) (HxWxC or HxW) <--->  tensor\n# --------------------------------------------\n\n\n# convert uint to 4-dimensional torch tensor\ndef uint2tensor4(img):\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return (\n        torch.from_numpy(np.ascontiguousarray(img))\n        .permute(2, 0, 1)\n        .float()\n        .div(255.0)\n        .unsqueeze(0)\n    )\n\n\n# convert uint to 3-dimensional torch tensor\ndef uint2tensor3(img):\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return (\n        torch.from_numpy(np.ascontiguousarray(img))\n        .permute(2, 0, 1)\n        .float()\n        .div(255.0)\n    )\n\n\n# convert 2/3/4-dimensional torch tensor to uint\ndef tensor2uint(img):\n    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n    return np.uint8((img * 255.0).round())\n\n\n# --------------------------------------------\n# numpy(single) (HxWxC) <--->  tensor\n# --------------------------------------------\n\n\n# convert single (HxWxC) to 3-dimensional torch tensor\ndef single2tensor3(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()\n\n\n# convert single (HxWxC) to 4-dimensional torch tensor\ndef 
single2tensor4(img):\n    return (\n        torch.from_numpy(np.ascontiguousarray(img))\n        .permute(2, 0, 1)\n        .float()\n        .unsqueeze(0)\n    )\n\n\n# convert torch tensor to single\ndef tensor2single(img):\n    img = img.data.squeeze().float().cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n\n    return img\n\n\n# convert torch tensor to single\ndef tensor2single3(img):\n    img = img.data.squeeze().float().cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n    elif img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return img\n\n\ndef single2tensor5(img):\n    return (\n        torch.from_numpy(np.ascontiguousarray(img))\n        .permute(2, 0, 1, 3)\n        .float()\n        .unsqueeze(0)\n    )\n\n\ndef single32tensor5(img):\n    return (\n        torch.from_numpy(np.ascontiguousarray(img))\n        .float()\n        .unsqueeze(0)\n        .unsqueeze(0)\n    )\n\n\ndef single42tensor4(img):\n    return (\n        torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()\n    )\n\n\n# from skimage.io import imread, imsave\ndef tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):\n    \"\"\"\n    Converts a torch Tensor into an image Numpy array of BGR channel order\n    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order\n    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)\n    \"\"\"\n    tensor = (\n        tensor.squeeze().float().cpu().clamp_(*min_max)\n    )  # squeeze first, then clamp\n    tensor = (tensor - min_max[0]) / (\n        min_max[1] - min_max[0]\n    )  # to range [0,1]\n    n_dim = tensor.dim()\n    if n_dim == 4:\n        n_img = len(tensor)\n        img_np = make_grid(\n            tensor, nrow=int(math.sqrt(n_img)), normalize=False\n        ).numpy()\n        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR\n    elif n_dim == 3:\n        img_np = tensor.numpy()\n        
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR\n    elif n_dim == 2:\n        img_np = tensor.numpy()\n    else:\n        raise TypeError(\n            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(\n                n_dim\n            )\n        )\n    if out_type == np.uint8:\n        img_np = (img_np * 255.0).round()\n        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.\n    return img_np.astype(out_type)\n\n\n\"\"\"\n# --------------------------------------------\n# Augmentation, flipe and/or rotate\n# --------------------------------------------\n# The following two are enough.\n# (1) augmet_img: numpy image of WxHxC or WxH\n# (2) augment_img_tensor4: tensor image 1xCxWxH\n# --------------------------------------------\n\"\"\"\n\n\ndef augment_img(img, mode=0):\n    \"\"\"Kai Zhang (github: https://github.com/cszn)\"\"\"\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return np.flipud(np.rot90(img))\n    elif mode == 2:\n        return np.flipud(img)\n    elif mode == 3:\n        return np.rot90(img, k=3)\n    elif mode == 4:\n        return np.flipud(np.rot90(img, k=2))\n    elif mode == 5:\n        return np.rot90(img)\n    elif mode == 6:\n        return np.rot90(img, k=2)\n    elif mode == 7:\n        return np.flipud(np.rot90(img, k=3))\n\n\ndef augment_img_tensor4(img, mode=0):\n    \"\"\"Kai Zhang (github: https://github.com/cszn)\"\"\"\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return img.rot90(1, [2, 3]).flip([2])\n    elif mode == 2:\n        return img.flip([2])\n    elif mode == 3:\n        return img.rot90(3, [2, 3])\n    elif mode == 4:\n        return img.rot90(2, [2, 3]).flip([2])\n    elif mode == 5:\n        return img.rot90(1, [2, 3])\n    elif mode == 6:\n        return img.rot90(2, [2, 3])\n    elif mode == 7:\n        return img.rot90(3, [2, 3]).flip([2])\n\n\ndef augment_img_tensor(img, mode=0):\n    \"\"\"Kai 
Zhang (github: https://github.com/cszn)\"\"\"\n    img_size = img.size()\n    img_np = img.data.cpu().numpy()\n    if len(img_size) == 3:\n        img_np = np.transpose(img_np, (1, 2, 0))\n    elif len(img_size) == 4:\n        img_np = np.transpose(img_np, (2, 3, 1, 0))\n    img_np = augment_img(img_np, mode=mode)\n    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))\n    if len(img_size) == 3:\n        img_tensor = img_tensor.permute(2, 0, 1)\n    elif len(img_size) == 4:\n        img_tensor = img_tensor.permute(3, 2, 0, 1)\n\n    return img_tensor.type_as(img)\n\n\ndef augment_img_np3(img, mode=0):\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return img.transpose(1, 0, 2)\n    elif mode == 2:\n        return img[::-1, :, :]\n    elif mode == 3:\n        img = img[::-1, :, :]\n        img = img.transpose(1, 0, 2)\n        return img\n    elif mode == 4:\n        return img[:, ::-1, :]\n    elif mode == 5:\n        img = img[:, ::-1, :]\n        img = img.transpose(1, 0, 2)\n        return img\n    elif mode == 6:\n        img = img[:, ::-1, :]\n        img = img[::-1, :, :]\n        return img\n    elif mode == 7:\n        img = img[:, ::-1, :]\n        img = img[::-1, :, :]\n        img = img.transpose(1, 0, 2)\n        return img\n\n\ndef augment_imgs(img_list, hflip=True, rot=True):\n    # horizontal flip OR rotate\n    hflip = hflip and random.random() < 0.5\n    vflip = rot and random.random() < 0.5\n    rot90 = rot and random.random() < 0.5\n\n    def _augment(img):\n        if hflip:\n            img = img[:, ::-1, :]\n        if vflip:\n            img = img[::-1, :, :]\n        if rot90:\n            img = img.transpose(1, 0, 2)\n        return img\n\n    return [_augment(img) for img in img_list]\n\n\n\"\"\"\n# --------------------------------------------\n# modcrop and shave\n# --------------------------------------------\n\"\"\"\n\n\ndef modcrop(img_in, scale):\n    # img_in: Numpy, HWC or HW\n    img = 
np.copy(img_in)\n    if img.ndim == 2:\n        H, W = img.shape\n        H_r, W_r = H % scale, W % scale\n        img = img[: H - H_r, : W - W_r]\n    elif img.ndim == 3:\n        H, W, C = img.shape\n        H_r, W_r = H % scale, W % scale\n        img = img[: H - H_r, : W - W_r, :]\n    else:\n        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))\n    return img\n\n\ndef shave(img_in, border=0):\n    # img_in: Numpy, HWC or HW\n    img = np.copy(img_in)\n    h, w = img.shape[:2]\n    img = img[border : h - border, border : w - border]\n    return img\n\n\n\"\"\"\n# --------------------------------------------\n# image processing process on numpy image\n# channel_convert(in_c, tar_type, img_list):\n# rgb2ycbcr(img, only_y=True):\n# bgr2ycbcr(img, only_y=True):\n# ycbcr2rgb(img):\n# --------------------------------------------\n\"\"\"\n\n\ndef rgb2ycbcr(img, only_y=True):\n    \"\"\"same as matlab rgb2ycbcr\n    only_y: only return Y channel\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    \"\"\"\n    in_img_type = img.dtype\n    img.astype(np.float32)\n    if in_img_type != np.uint8:\n        img *= 255.0\n    # convert\n    if only_y:\n        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0\n    else:\n        rlt = np.matmul(\n            img,\n            [\n                [65.481, -37.797, 112.0],\n                [128.553, -74.203, -93.786],\n                [24.966, 112.0, -18.214],\n            ],\n        ) / 255.0 + [16, 128, 128]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.0\n    return rlt.astype(in_img_type)\n\n\ndef ycbcr2rgb(img):\n    \"\"\"same as matlab ycbcr2rgb\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    \"\"\"\n    in_img_type = img.dtype\n    img.astype(np.float32)\n    if in_img_type != np.uint8:\n        img *= 255.0\n    # convert\n    rlt = np.matmul(\n        img,\n        [\n            [0.00456621, 0.00456621, 0.00456621],\n  
          [0, -0.00153632, 0.00791071],\n            [0.00625893, -0.00318811, 0],\n        ],\n    ) * 255.0 + [-222.921, 135.576, -276.836]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.0\n    return rlt.astype(in_img_type)\n\n\ndef bgr2ycbcr(img, only_y=True):\n    \"\"\"bgr version of rgb2ycbcr\n    only_y: only return Y channel\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    \"\"\"\n    in_img_type = img.dtype\n    img.astype(np.float32)\n    if in_img_type != np.uint8:\n        img *= 255.0\n    # convert\n    if only_y:\n        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0\n    else:\n        rlt = np.matmul(\n            img,\n            [\n                [24.966, 112.0, -18.214],\n                [128.553, -74.203, -93.786],\n                [65.481, -37.797, 112.0],\n            ],\n        ) / 255.0 + [16, 128, 128]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.0\n    return rlt.astype(in_img_type)\n\n\ndef channel_convert(in_c, tar_type, img_list):\n    # conversion among BGR, gray and y\n    if in_c == 3 and tar_type == 'gray':  # BGR to gray\n        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]\n        return [np.expand_dims(img, axis=2) for img in gray_list]\n    elif in_c == 3 and tar_type == 'y':  # BGR to y\n        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]\n        return [np.expand_dims(img, axis=2) for img in y_list]\n    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR\n        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]\n    else:\n        return img_list\n\n\n\"\"\"\n# --------------------------------------------\n# metric, PSNR and SSIM\n# --------------------------------------------\n\"\"\"\n\n\n# --------------------------------------------\n# PSNR\n# --------------------------------------------\ndef calculate_psnr(img1, img2, border=0):\n    # 
img1 and img2 have range [0, 255]\n    # img1 = img1.squeeze()\n    # img2 = img2.squeeze()\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    h, w = img1.shape[:2]\n    img1 = img1[border : h - border, border : w - border]\n    img2 = img2[border : h - border, border : w - border]\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    mse = np.mean((img1 - img2) ** 2)\n    if mse == 0:\n        return float('inf')\n    return 20 * math.log10(255.0 / math.sqrt(mse))\n\n\n# --------------------------------------------\n# SSIM\n# --------------------------------------------\ndef calculate_ssim(img1, img2, border=0):\n    \"\"\"calculate SSIM\n    the same outputs as MATLAB's\n    img1, img2: [0, 255]\n    \"\"\"\n    # img1 = img1.squeeze()\n    # img2 = img2.squeeze()\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    h, w = img1.shape[:2]\n    img1 = img1[border : h - border, border : w - border]\n    img2 = img2[border : h - border, border : w - border]\n\n    if img1.ndim == 2:\n        return ssim(img1, img2)\n    elif img1.ndim == 3:\n        if img1.shape[2] == 3:\n            ssims = []\n            for i in range(3):\n                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))\n            return np.array(ssims).mean()\n        elif img1.shape[2] == 1:\n            return ssim(np.squeeze(img1), np.squeeze(img2))\n    else:\n        raise ValueError('Wrong input image dimensions.')\n\n\ndef ssim(img1, img2):\n    C1 = (0.01 * 255) ** 2\n    C2 = (0.03 * 255) ** 2\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    kernel = cv2.getGaussianKernel(11, 1.5)\n    window = np.outer(kernel, kernel.transpose())\n\n    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid\n    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]\n    mu1_sq = mu1**2\n    mu2_sq = mu2**2\n    mu1_mu2 = 
mu1 * mu2\n    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq\n    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq\n    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (\n        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)\n    )\n    return ssim_map.mean()\n\n\n\"\"\"\n# --------------------------------------------\n# matlab's bicubic imresize (numpy and torch) [0, 1]\n# --------------------------------------------\n\"\"\"\n\n\n# matlab 'imresize' function, now only support 'bicubic'\ndef cubic(x):\n    absx = torch.abs(x)\n    absx2 = absx**2\n    absx3 = absx**3\n    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (\n        -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2\n    ) * (((absx > 1) * (absx <= 2)).type_as(absx))\n\n\ndef calculate_weights_indices(\n    in_length, out_length, scale, kernel, kernel_width, antialiasing\n):\n    if (scale < 1) and (antialiasing):\n        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width\n        kernel_width = kernel_width / scale\n\n    # Output-space coordinates\n    x = torch.linspace(1, out_length, out_length)\n\n    # Input-space coordinates. Calculate the inverse mapping such that 0.5\n    # in output space maps to 0.5 in input space, and 0.5+scale in output\n    # space maps to 1.5 in input space.\n    u = x / scale + 0.5 * (1 - 1 / scale)\n\n    # What is the left-most pixel that can be involved in the computation?\n    left = torch.floor(u - kernel_width / 2)\n\n    # What is the maximum number of pixels that can be involved in the\n    # computation?  
Note: it's OK to use an extra pixel here; if the\n    # corresponding weights are all zero, it will be eliminated at the end\n    # of this function.\n    P = math.ceil(kernel_width) + 2\n\n    # The indices of the input pixels involved in computing the k-th output\n    # pixel are in row k of the indices matrix.\n    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(\n        0, P - 1, P\n    ).view(1, P).expand(out_length, P)\n\n    # The weights used to compute the k-th output pixel are in row k of the\n    # weights matrix.\n    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices\n    # apply cubic kernel\n    if (scale < 1) and (antialiasing):\n        weights = scale * cubic(distance_to_center * scale)\n    else:\n        weights = cubic(distance_to_center)\n    # Normalize the weights matrix so that each row sums to 1.\n    weights_sum = torch.sum(weights, 1).view(out_length, 1)\n    weights = weights / weights_sum.expand(out_length, P)\n\n    # If a column in weights is all zero, get rid of it. 
only consider the first and last column.\n    weights_zero_tmp = torch.sum((weights == 0), 0)\n    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 1, P - 2)\n        weights = weights.narrow(1, 1, P - 2)\n    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 0, P - 2)\n        weights = weights.narrow(1, 0, P - 2)\n    weights = weights.contiguous()\n    indices = indices.contiguous()\n    sym_len_s = -indices.min() + 1\n    sym_len_e = indices.max() - in_length\n    indices = indices + sym_len_s - 1\n    return weights, indices, int(sym_len_s), int(sym_len_e)\n\n\n# --------------------------------------------\n# imresize for tensor image [0, 1]\n# --------------------------------------------\ndef imresize(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: pytorch tensor, CHW or HW [0,1]\n    # output: CHW or HW [0,1] w/o round\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(0)\n    in_C, in_H, in_W = img.size()\n    out_C, out_H, out_W = (\n        in_C,\n        math.ceil(in_H * scale),\n        math.ceil(in_W * scale),\n    )\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing\n    )\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing\n    )\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)\n    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:, :sym_len_Hs, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[:, -sym_len_He:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(in_C, out_H, in_W)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[j, i, :] = (\n                img_aug[j, idx : idx + kernel_width, :]\n                .transpose(0, 1)\n                .mv(weights_H[i])\n            )\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)\n    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :, :sym_len_Ws]\n    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(2, inv_idx)\n    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, :, -sym_len_We:]\n    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, 
-1).long()\n    sym_patch_inv = sym_patch.index_select(2, inv_idx)\n    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(in_C, out_H, out_W)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(\n                weights_W[i]\n            )\n    if need_squeeze:\n        out_2.squeeze_()\n    return out_2\n\n\n# --------------------------------------------\n# imresize for numpy image [0, 1]\n# --------------------------------------------\ndef imresize_np(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: Numpy, HWC or HW [0,1]\n    # output: HWC or HW [0,1] w/o round\n    img = torch.from_numpy(img)\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(2)\n\n    in_H, in_W, in_C = img.size()\n    out_C, out_H, out_W = (\n        in_C,\n        math.ceil(in_H * scale),\n        math.ceil(in_W * scale),\n    )\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing\n    )\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing\n    )\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)\n    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:sym_len_Hs, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[-sym_len_He:, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(out_H, in_W, in_C)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[i, :, j] = (\n                img_aug[idx : idx + kernel_width, :, j]\n                .transpose(0, 1)\n                .mv(weights_H[i])\n            )\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)\n    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :sym_len_Ws, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, -sym_len_We:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, 
-1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(out_H, out_W, in_C)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(\n                weights_W[i]\n            )\n    if need_squeeze:\n        out_2.squeeze_()\n\n    return out_2.numpy()\n\n\nif __name__ == '__main__':\n    print('---')\n#    img = imread_uint('test.bmp', 3)\n#    img = uint2single(img)\n#    img_bicubic = imresize_np(img, 1/4)\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/losses/__init__.py",
    "content": "from src.stablediffusion.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/losses/contperceptual.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?\n\n\nclass LPIPSWithDiscriminator(nn.Module):\n    def __init__(\n        self,\n        disc_start,\n        logvar_init=0.0,\n        kl_weight=1.0,\n        pixelloss_weight=1.0,\n        disc_num_layers=3,\n        disc_in_channels=3,\n        disc_factor=1.0,\n        disc_weight=1.0,\n        perceptual_weight=1.0,\n        use_actnorm=False,\n        disc_conditional=False,\n        disc_loss='hinge',\n    ):\n\n        super().__init__()\n        assert disc_loss in ['hinge', 'vanilla']\n        self.kl_weight = kl_weight\n        self.pixel_weight = pixelloss_weight\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        # output log variance\n        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)\n\n        self.discriminator = NLayerDiscriminator(\n            input_nc=disc_in_channels,\n            n_layers=disc_num_layers,\n            use_actnorm=use_actnorm,\n        ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n        self.disc_loss = (\n            hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss\n        )\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(\n                nll_loss, last_layer, retain_graph=True\n            )[0]\n            g_grads = torch.autograd.grad(\n                g_loss, last_layer, retain_graph=True\n            )[0]\n        else:\n            nll_grads = torch.autograd.grad(\n                nll_loss, self.last_layer[0], retain_graph=True\n            )[0]\n            g_grads = torch.autograd.grad(\n                g_loss, 
self.last_layer[0], retain_graph=True\n            )[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(\n        self,\n        inputs,\n        reconstructions,\n        posteriors,\n        optimizer_idx,\n        global_step,\n        last_layer=None,\n        cond=None,\n        split='train',\n        weights=None,\n    ):\n        rec_loss = torch.abs(\n            inputs.contiguous() - reconstructions.contiguous()\n        )\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(\n                inputs.contiguous(), reconstructions.contiguous()\n            )\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n\n        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar\n        weighted_nll_loss = nll_loss\n        if weights is not None:\n            weighted_nll_loss = weights * nll_loss\n        weighted_nll_loss = (\n            torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]\n        )\n        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        kl_loss = posteriors.kl()\n        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(\n                    torch.cat((reconstructions.contiguous(), cond), dim=1)\n                )\n            g_loss = -torch.mean(logits_fake)\n\n            if self.disc_factor > 0.0:\n                try:\n                    d_weight = self.calculate_adaptive_weight(\n                        nll_loss, g_loss, 
last_layer=last_layer\n                    )\n                except RuntimeError:\n                    assert not self.training\n                    d_weight = torch.tensor(0.0)\n            else:\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(\n                self.disc_factor,\n                global_step,\n                threshold=self.discriminator_iter_start,\n            )\n            loss = (\n                weighted_nll_loss\n                + self.kl_weight * kl_loss\n                + d_weight * disc_factor * g_loss\n            )\n\n            log = {\n                '{}/total_loss'.format(split): loss.clone().detach().mean(),\n                '{}/logvar'.format(split): self.logvar.detach(),\n                '{}/kl_loss'.format(split): kl_loss.detach().mean(),\n                '{}/nll_loss'.format(split): nll_loss.detach().mean(),\n                '{}/rec_loss'.format(split): rec_loss.detach().mean(),\n                '{}/d_weight'.format(split): d_weight.detach(),\n                '{}/disc_factor'.format(split): torch.tensor(disc_factor),\n                '{}/g_loss'.format(split): g_loss.detach().mean(),\n            }\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(\n                    reconstructions.contiguous().detach()\n                )\n            else:\n                logits_real = self.discriminator(\n                    torch.cat((inputs.contiguous().detach(), cond), dim=1)\n                )\n                logits_fake = self.discriminator(\n                    torch.cat(\n                        (reconstructions.contiguous().detach(), cond), dim=1\n                    )\n                )\n\n            disc_factor = adopt_weight(\n                self.disc_factor,\n 
               global_step,\n                threshold=self.discriminator_iter_start,\n            )\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\n                '{}/disc_loss'.format(split): d_loss.clone().detach().mean(),\n                '{}/logits_real'.format(split): logits_real.detach().mean(),\n                '{}/logits_fake'.format(split): logits_fake.detach().mean(),\n            }\n            return d_loss, log\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/losses/vqperceptual.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discriminator.model import (\n    NLayerDiscriminator,\n    weights_init,\n)\nfrom taming.modules.losses.lpips import LPIPS\nfrom taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss\n\n\ndef hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):\n    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]\n    loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])\n    loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])\n    loss_real = (weights * loss_real).sum() / weights.sum()\n    loss_fake = (weights * loss_fake).sum() / weights.sum()\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\n\ndef adopt_weight(weight, global_step, threshold=0, value=0.0):\n    if global_step < threshold:\n        weight = value\n    return weight\n\n\ndef measure_perplexity(predicted_indices, n_embed):\n    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py\n    # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally\n    encodings = (\n        F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)\n    )\n    avg_probs = encodings.mean(0)\n    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()\n    cluster_use = torch.sum(avg_probs > 0)\n    return perplexity, cluster_use\n\n\ndef l1(x, y):\n    return torch.abs(x - y)\n\n\ndef l2(x, y):\n    return torch.pow((x - y), 2)\n\n\nclass VQLPIPSWithDiscriminator(nn.Module):\n    def __init__(\n        self,\n        disc_start,\n        codebook_weight=1.0,\n        pixelloss_weight=1.0,\n        disc_num_layers=3,\n        disc_in_channels=3,\n        disc_factor=1.0,\n        disc_weight=1.0,\n        perceptual_weight=1.0,\n        use_actnorm=False,\n        disc_conditional=False,\n        disc_ndf=64,\n        disc_loss='hinge',\n        n_classes=None,\n        perceptual_loss='lpips',\n        pixel_loss='l1',\n    ):\n        super().__init__()\n        assert disc_loss in ['hinge', 'vanilla']\n        assert perceptual_loss in ['lpips', 'clips', 'dists']\n        assert pixel_loss in ['l1', 'l2']\n        self.codebook_weight = codebook_weight\n        self.pixel_weight = pixelloss_weight\n        if perceptual_loss == 'lpips':\n            print(f'{self.__class__.__name__}: Running with LPIPS.')\n            self.perceptual_loss = LPIPS().eval()\n        else:\n            raise ValueError(\n                f'Unknown perceptual loss: >> {perceptual_loss} <<'\n            )\n        self.perceptual_weight = perceptual_weight\n\n        if pixel_loss == 'l1':\n            self.pixel_loss = l1\n        else:\n            self.pixel_loss = l2\n\n        self.discriminator = NLayerDiscriminator(\n            input_nc=disc_in_channels,\n            n_layers=disc_num_layers,\n            use_actnorm=use_actnorm,\n            ndf=disc_ndf,\n        ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n 
       if disc_loss == 'hinge':\n            self.disc_loss = hinge_d_loss\n        elif disc_loss == 'vanilla':\n            self.disc_loss = vanilla_d_loss\n        else:\n            raise ValueError(f\"Unknown GAN loss '{disc_loss}'.\")\n        print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.')\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n        self.n_classes = n_classes\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(\n                nll_loss, last_layer, retain_graph=True\n            )[0]\n            g_grads = torch.autograd.grad(\n                g_loss, last_layer, retain_graph=True\n            )[0]\n        else:\n            nll_grads = torch.autograd.grad(\n                nll_loss, self.last_layer[0], retain_graph=True\n            )[0]\n            g_grads = torch.autograd.grad(\n                g_loss, self.last_layer[0], retain_graph=True\n            )[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(\n        self,\n        codebook_loss,\n        inputs,\n        reconstructions,\n        optimizer_idx,\n        global_step,\n        last_layer=None,\n        cond=None,\n        split='train',\n        predicted_indices=None,\n    ):\n        if not exists(codebook_loss):\n            codebook_loss = torch.tensor([0.0]).to(inputs.device)\n        # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        rec_loss = self.pixel_loss(\n            inputs.contiguous(), reconstructions.contiguous()\n        )\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(\n                
inputs.contiguous(), reconstructions.contiguous()\n            )\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n        else:\n            p_loss = torch.tensor([0.0])\n\n        nll_loss = rec_loss\n        # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        nll_loss = torch.mean(nll_loss)\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(\n                    torch.cat((reconstructions.contiguous(), cond), dim=1)\n                )\n            g_loss = -torch.mean(logits_fake)\n\n            try:\n                d_weight = self.calculate_adaptive_weight(\n                    nll_loss, g_loss, last_layer=last_layer\n                )\n            except RuntimeError:\n                assert not self.training\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(\n                self.disc_factor,\n                global_step,\n                threshold=self.discriminator_iter_start,\n            )\n            loss = (\n                nll_loss\n                + d_weight * disc_factor * g_loss\n                + self.codebook_weight * codebook_loss.mean()\n            )\n\n            log = {\n                '{}/total_loss'.format(split): loss.clone().detach().mean(),\n                '{}/quant_loss'.format(split): codebook_loss.detach().mean(),\n                '{}/nll_loss'.format(split): nll_loss.detach().mean(),\n                '{}/rec_loss'.format(split): rec_loss.detach().mean(),\n                '{}/p_loss'.format(split): p_loss.detach().mean(),\n                '{}/d_weight'.format(split): d_weight.detach(),\n                '{}/disc_factor'.format(split): 
torch.tensor(disc_factor),\n                '{}/g_loss'.format(split): g_loss.detach().mean(),\n            }\n            if predicted_indices is not None:\n                assert self.n_classes is not None\n                with torch.no_grad():\n                    perplexity, cluster_usage = measure_perplexity(\n                        predicted_indices, self.n_classes\n                    )\n                log[f'{split}/perplexity'] = perplexity\n                log[f'{split}/cluster_usage'] = cluster_usage\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(\n                    reconstructions.contiguous().detach()\n                )\n            else:\n                logits_real = self.discriminator(\n                    torch.cat((inputs.contiguous().detach(), cond), dim=1)\n                )\n                logits_fake = self.discriminator(\n                    torch.cat(\n                        (reconstructions.contiguous().detach(), cond), dim=1\n                    )\n                )\n\n            disc_factor = adopt_weight(\n                self.disc_factor,\n                global_step,\n                threshold=self.discriminator_iter_start,\n            )\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\n                '{}/disc_loss'.format(split): d_loss.clone().detach().mean(),\n                '{}/logits_real'.format(split): logits_real.detach().mean(),\n                '{}/logits_fake'.format(split): logits_fake.detach().mean(),\n            }\n            return d_loss, log\n"
  },
  {
    "path": "src/stablediffusion/ldm/modules/x_transformer.py",
    "content": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\nfrom inspect import isfunction\nfrom collections import namedtuple\nfrom einops import rearrange, repeat, reduce\n\n# constants\n\nDEFAULT_DIM_HEAD = 64\n\nIntermediates = namedtuple(\n    'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']\n)\n\nLayerIntermediates = namedtuple(\n    'Intermediates', ['hiddens', 'attn_intermediates']\n)\n\n\nclass AbsolutePositionalEmbedding(nn.Module):\n    def __init__(self, dim, max_seq_len):\n        super().__init__()\n        self.emb = nn.Embedding(max_seq_len, dim)\n        self.init_()\n\n    def init_(self):\n        nn.init.normal_(self.emb.weight, std=0.02)\n\n    def forward(self, x):\n        n = torch.arange(x.shape[1], device=x.device)\n        return self.emb(n)[None, :, :]\n\n\nclass FixedPositionalEmbedding(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer('inv_freq', inv_freq)\n\n    def forward(self, x, seq_dim=1, offset=0):\n        t = (\n            torch.arange(x.shape[seq_dim], device=x.device).type_as(\n                self.inv_freq\n            )\n            + offset\n        )\n        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)\n        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)\n        return emb[None, :, :]\n\n\n# helpers\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef always(val):\n    def inner(*args, **kwargs):\n        return val\n\n    return inner\n\n\ndef not_equals(val):\n    def inner(x):\n        return x != val\n\n    return inner\n\n\ndef equals(val):\n    def inner(x):\n        
return x == val\n\n    return inner\n\n\ndef max_neg_value(tensor):\n    return -torch.finfo(tensor.dtype).max\n\n\n# keyword argument helpers\n\n\ndef pick_and_pop(keys, d):\n    values = list(map(lambda key: d.pop(key), keys))\n    return dict(zip(keys, values))\n\n\ndef group_dict_by_key(cond, d):\n    return_val = [dict(), dict()]\n    for key in d.keys():\n        match = bool(cond(key))\n        ind = int(not match)\n        return_val[ind][key] = d[key]\n    return (*return_val,)\n\n\ndef string_begins_with(prefix, str):\n    return str.startswith(prefix)\n\n\ndef group_by_key_prefix(prefix, d):\n    return group_dict_by_key(partial(string_begins_with, prefix), d)\n\n\ndef groupby_prefix_and_trim(prefix, d):\n    kwargs_with_prefix, kwargs = group_dict_by_key(\n        partial(string_begins_with, prefix), d\n    )\n    kwargs_without_prefix = dict(\n        map(\n            lambda x: (x[0][len(prefix) :], x[1]),\n            tuple(kwargs_with_prefix.items()),\n        )\n    )\n    return kwargs_without_prefix, kwargs\n\n\n# classes\nclass Scale(nn.Module):\n    def __init__(self, value, fn):\n        super().__init__()\n        self.value = value\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.value, *rest)\n\n\nclass Rezero(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n        self.g = nn.Parameter(torch.zeros(1))\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.g, *rest)\n\n\nclass ScaleNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.scale = dim**-0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(1))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim, 
eps=1e-8):\n        super().__init__()\n        self.scale = dim**-0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass Residual(nn.Module):\n    def forward(self, x, residual):\n        return x + residual\n\n\nclass GRUGating(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.gru = nn.GRUCell(dim, dim)\n\n    def forward(self, x, residual):\n        gated_output = self.gru(\n            rearrange(x, 'b n d -> (b n) d'),\n            rearrange(residual, 'b n d -> (b n) d'),\n        )\n\n        return gated_output.reshape_as(x)\n\n\n# feedforward\n\n\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = (\n            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())\n            if not glu\n            else GEGLU(dim, inner_dim)\n        )\n\n        self.net = nn.Sequential(\n            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\n# attention.\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        dim_head=DEFAULT_DIM_HEAD,\n        heads=8,\n        causal=False,\n        mask=None,\n        talking_heads=False,\n        sparse_topk=None,\n        use_entmax15=False,\n        num_mem_kv=0,\n        dropout=0.0,\n        on_attn=False,\n    ):\n        super().__init__()\n        if 
use_entmax15:\n            raise NotImplementedError(\n                'Check out entmax activation instead of softmax activation!'\n            )\n        self.scale = dim_head**-0.5\n        self.heads = heads\n        self.causal = causal\n        self.mask = mask\n\n        inner_dim = dim_head * heads\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(dim, inner_dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n        # talking heads\n        self.talking_heads = talking_heads\n        if talking_heads:\n            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n\n        # explicit topk sparse attention\n        self.sparse_topk = sparse_topk\n\n        # entmax\n        # self.attn_fn = entmax15 if use_entmax15 else F.softmax\n        self.attn_fn = F.softmax\n\n        # add memory key / values\n        self.num_mem_kv = num_mem_kv\n        if num_mem_kv > 0:\n            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n\n        # attention on attention\n        self.attn_on_attn = on_attn\n        self.to_out = (\n            nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())\n            if on_attn\n            else nn.Linear(inner_dim, dim)\n        )\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        context_mask=None,\n        rel_pos=None,\n        sinusoidal_emb=None,\n        prev_attn=None,\n        mem=None,\n    ):\n        b, n, _, h, talking_heads, device = (\n            *x.shape,\n            self.heads,\n            self.talking_heads,\n            x.device,\n        )\n        kv_input = default(context, x)\n\n        q_input = x\n        k_input = kv_input\n        v_input = 
kv_input\n\n        if exists(mem):\n            k_input = torch.cat((mem, k_input), dim=-2)\n            v_input = torch.cat((mem, v_input), dim=-2)\n\n        if exists(sinusoidal_emb):\n            # in shortformer, the query would start at a position offset depending on the past cached memory\n            offset = k_input.shape[-2] - q_input.shape[-2]\n            q_input = q_input + sinusoidal_emb(q_input, offset=offset)\n            k_input = k_input + sinusoidal_emb(k_input)\n\n        q = self.to_q(q_input)\n        k = self.to_k(k_input)\n        v = self.to_v(v_input)\n\n        q, k, v = map(\n            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)\n        )\n\n        input_mask = None\n        if any(map(exists, (mask, context_mask))):\n            q_mask = default(\n                mask, lambda: torch.ones((b, n), device=device).bool()\n            )\n            k_mask = q_mask if not exists(context) else context_mask\n            k_mask = default(\n                k_mask,\n                lambda: torch.ones((b, k.shape[-2]), device=device).bool(),\n            )\n            q_mask = rearrange(q_mask, 'b i -> b () i ()')\n            k_mask = rearrange(k_mask, 'b j -> b () () j')\n            input_mask = q_mask * k_mask\n\n        if self.num_mem_kv > 0:\n            mem_k, mem_v = map(\n                lambda t: repeat(t, 'h n d -> b h n d', b=b),\n                (self.mem_k, self.mem_v),\n            )\n            k = torch.cat((mem_k, k), dim=-2)\n            v = torch.cat((mem_v, v), dim=-2)\n            if exists(input_mask):\n                input_mask = F.pad(\n                    input_mask, (self.num_mem_kv, 0), value=True\n                )\n\n        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n        mask_value = max_neg_value(dots)\n\n        if exists(prev_attn):\n            dots = dots + prev_attn\n\n        pre_softmax_attn = dots\n\n        if talking_heads:\n            dots = einsum(\n  
              'b h i j, h k -> b k i j', dots, self.pre_softmax_proj\n            ).contiguous()\n\n        if exists(rel_pos):\n            dots = rel_pos(dots)\n\n        if exists(input_mask):\n            dots.masked_fill_(~input_mask, mask_value)\n            del input_mask\n\n        if self.causal:\n            i, j = dots.shape[-2:]\n            r = torch.arange(i, device=device)\n            mask = rearrange(r, 'i -> () () i ()') < rearrange(\n                r, 'j -> () () () j'\n            )\n            mask = F.pad(mask, (j - i, 0), value=False)\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:\n            top, _ = dots.topk(self.sparse_topk, dim=-1)\n            vk = top[..., -1].unsqueeze(-1).expand_as(dots)\n            mask = dots < vk\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        attn = self.attn_fn(dots, dim=-1)\n        post_softmax_attn = attn\n\n        attn = self.dropout(attn)\n\n        if talking_heads:\n            attn = einsum(\n                'b h i j, h k -> b k i j', attn, self.post_softmax_proj\n            ).contiguous()\n\n        out = einsum('b h i j, b h j d -> b h i d', attn, v)\n        out = rearrange(out, 'b h n d -> b n (h d)')\n\n        intermediates = Intermediates(\n            pre_softmax_attn=pre_softmax_attn,\n            post_softmax_attn=post_softmax_attn,\n        )\n\n        return self.to_out(out), intermediates\n\n\nclass AttentionLayers(nn.Module):\n    def __init__(\n        self,\n        dim,\n        depth,\n        heads=8,\n        causal=False,\n        cross_attend=False,\n        only_cross=False,\n        use_scalenorm=False,\n        use_rmsnorm=False,\n        use_rezero=False,\n        rel_pos_num_buckets=32,\n        rel_pos_max_distance=128,\n        position_infused_attn=False,\n        custom_layers=None,\n        sandwich_coef=None,\n        
par_ratio=None,\n        residual_attn=False,\n        cross_residual_attn=False,\n        macaron=False,\n        pre_norm=True,\n        gate_residual=False,\n        **kwargs,\n    ):\n        super().__init__()\n        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)\n        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)\n\n        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)\n\n        self.dim = dim\n        self.depth = depth\n        self.layers = nn.ModuleList([])\n\n        self.has_pos_emb = position_infused_attn\n        self.pia_pos_emb = (\n            FixedPositionalEmbedding(dim) if position_infused_attn else None\n        )\n        self.rotary_pos_emb = always(None)\n\n        assert (\n            rel_pos_num_buckets <= rel_pos_max_distance\n        ), 'number of relative position buckets must be less than the relative position max distance'\n        self.rel_pos = None\n\n        self.pre_norm = pre_norm\n\n        self.residual_attn = residual_attn\n        self.cross_residual_attn = cross_residual_attn\n\n        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm\n        norm_class = RMSNorm if use_rmsnorm else norm_class\n        norm_fn = partial(norm_class, dim)\n\n        norm_fn = nn.Identity if use_rezero else norm_fn\n        branch_fn = Rezero if use_rezero else None\n\n        if cross_attend and not only_cross:\n            default_block = ('a', 'c', 'f')\n        elif cross_attend and only_cross:\n            default_block = ('c', 'f')\n        else:\n            default_block = ('a', 'f')\n\n        if macaron:\n            default_block = ('f',) + default_block\n\n        if exists(custom_layers):\n            layer_types = custom_layers\n        elif exists(par_ratio):\n            par_depth = depth * len(default_block)\n            assert 1 < par_ratio <= par_depth, 'par ratio out of range'\n            default_block = tuple(filter(not_equals('f'), default_block))\n            par_attn 
= par_depth // par_ratio\n            depth_cut = (\n                par_depth * 2 // 3\n            )  # 2 / 3 attention layer cutoff suggested by PAR paper\n            par_width = (depth_cut + depth_cut // par_attn) // par_attn\n            assert (\n                len(default_block) <= par_width\n            ), 'default block is too large for par_ratio'\n            par_block = default_block + ('f',) * (\n                par_width - len(default_block)\n            )\n            par_head = par_block * par_attn\n            layer_types = par_head + ('f',) * (par_depth - len(par_head))\n        elif exists(sandwich_coef):\n            assert (\n                sandwich_coef > 0 and sandwich_coef <= depth\n            ), 'sandwich coefficient should be less than the depth'\n            layer_types = (\n                ('a',) * sandwich_coef\n                + default_block * (depth - sandwich_coef)\n                + ('f',) * sandwich_coef\n            )\n        else:\n            layer_types = default_block * depth\n\n        self.layer_types = layer_types\n        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))\n\n        for layer_type in self.layer_types:\n            if layer_type == 'a':\n                layer = Attention(\n                    dim, heads=heads, causal=causal, **attn_kwargs\n                )\n            elif layer_type == 'c':\n                layer = Attention(dim, heads=heads, **attn_kwargs)\n            elif layer_type == 'f':\n                layer = FeedForward(dim, **ff_kwargs)\n                layer = layer if not macaron else Scale(0.5, layer)\n            else:\n                raise Exception(f'invalid layer type {layer_type}')\n\n            if isinstance(layer, Attention) and exists(branch_fn):\n                layer = branch_fn(layer)\n\n            if gate_residual:\n                residual_fn = GRUGating(dim)\n            else:\n                residual_fn = Residual()\n\n            
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        context_mask=None,\n        mems=None,\n        return_hiddens=False,\n        **kwargs,\n    ):\n        hiddens = []\n        intermediates = []\n        prev_attn = None\n        prev_cross_attn = None\n\n        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers\n\n        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(\n            zip(self.layer_types, self.layers)\n        ):\n            is_last = ind == (len(self.layers) - 1)\n\n            if layer_type == 'a':\n                hiddens.append(x)\n                layer_mem = mems.pop(0)\n\n            residual = x\n\n            if self.pre_norm:\n                x = norm(x)\n\n            if layer_type == 'a':\n                out, inter = block(\n                    x,\n                    mask=mask,\n                    sinusoidal_emb=self.pia_pos_emb,\n                    rel_pos=self.rel_pos,\n                    prev_attn=prev_attn,\n                    mem=layer_mem,\n                )\n            elif layer_type == 'c':\n                out, inter = block(\n                    x,\n                    context=context,\n                    mask=mask,\n                    context_mask=context_mask,\n                    prev_attn=prev_cross_attn,\n                )\n            elif layer_type == 'f':\n                out = block(x)\n\n            x = residual_fn(out, residual)\n\n            if layer_type in ('a', 'c'):\n                intermediates.append(inter)\n\n            if layer_type == 'a' and self.residual_attn:\n                prev_attn = inter.pre_softmax_attn\n            elif layer_type == 'c' and self.cross_residual_attn:\n                prev_cross_attn = inter.pre_softmax_attn\n\n            if not self.pre_norm and not is_last:\n                x = norm(x)\n\n        if 
return_hiddens:\n            intermediates = LayerIntermediates(\n                hiddens=hiddens, attn_intermediates=intermediates\n            )\n\n            return x, intermediates\n\n        return x\n\n\nclass Encoder(AttentionLayers):\n    def __init__(self, **kwargs):\n        assert 'causal' not in kwargs, 'cannot set causality on encoder'\n        super().__init__(causal=False, **kwargs)\n\n\nclass TransformerWrapper(nn.Module):\n    def __init__(\n        self,\n        *,\n        num_tokens,\n        max_seq_len,\n        attn_layers,\n        emb_dim=None,\n        max_mem_len=0.0,\n        emb_dropout=0.0,\n        num_memory_tokens=None,\n        tie_embedding=False,\n        use_pos_emb=True,\n    ):\n        super().__init__()\n        assert isinstance(\n            attn_layers, AttentionLayers\n        ), 'attention layers must be one of Encoder or Decoder'\n\n        dim = attn_layers.dim\n        emb_dim = default(emb_dim, dim)\n\n        self.max_seq_len = max_seq_len\n        self.max_mem_len = max_mem_len\n        self.num_tokens = num_tokens\n\n        self.token_emb = nn.Embedding(num_tokens, emb_dim)\n        self.pos_emb = (\n            AbsolutePositionalEmbedding(emb_dim, max_seq_len)\n            if (use_pos_emb and not attn_layers.has_pos_emb)\n            else always(0)\n        )\n        self.emb_dropout = nn.Dropout(emb_dropout)\n\n        self.project_emb = (\n            nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()\n        )\n        self.attn_layers = attn_layers\n        self.norm = nn.LayerNorm(dim)\n\n        self.init_()\n\n        self.to_logits = (\n            nn.Linear(dim, num_tokens)\n            if not tie_embedding\n            else lambda t: t @ self.token_emb.weight.t()\n        )\n\n        # memory tokens (like [cls]) from Memory Transformers paper\n        num_memory_tokens = default(num_memory_tokens, 0)\n        self.num_memory_tokens = num_memory_tokens\n        if num_memory_tokens > 
0:\n            self.memory_tokens = nn.Parameter(\n                torch.randn(num_memory_tokens, dim)\n            )\n\n            # let funnel encoder know number of memory tokens, if specified\n            if hasattr(attn_layers, 'num_memory_tokens'):\n                attn_layers.num_memory_tokens = num_memory_tokens\n\n    def init_(self):\n        nn.init.normal_(self.token_emb.weight, std=0.02)\n\n    def forward(\n        self,\n        x,\n        return_embeddings=False,\n        mask=None,\n        return_mems=False,\n        return_attn=False,\n        mems=None,\n        embedding_manager=None,\n        **kwargs,\n    ):\n        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens\n\n        embedded_x = self.token_emb(x)\n\n        if embedding_manager:\n            x = embedding_manager(x, embedded_x)\n        else:\n            x = embedded_x\n\n        x = x + self.pos_emb(x)\n        x = self.emb_dropout(x)\n\n        x = self.project_emb(x)\n\n        if num_mem > 0:\n            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)\n            x = torch.cat((mem, x), dim=1)\n\n            # auto-handle masking after appending memory tokens\n            if exists(mask):\n                mask = F.pad(mask, (num_mem, 0), value=True)\n\n        x, intermediates = self.attn_layers(\n            x, mask=mask, mems=mems, return_hiddens=True, **kwargs\n        )\n        x = self.norm(x)\n\n        mem, x = x[:, :num_mem], x[:, num_mem:]\n\n        out = self.to_logits(x) if not return_embeddings else x\n\n        if return_mems:\n            hiddens = intermediates.hiddens\n            new_mems = (\n                list(\n                    map(\n                        lambda pair: torch.cat(pair, dim=-2),\n                        zip(mems, hiddens),\n                    )\n                )\n                if exists(mems)\n                else hiddens\n            )\n            new_mems = list(\n                map(\n              
      lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems\n                )\n            )\n            return out, new_mems\n\n        if return_attn:\n            attn_maps = list(\n                map(\n                    lambda t: t.post_softmax_attn,\n                    intermediates.attn_intermediates,\n                )\n            )\n            return out, attn_maps\n\n        return out\n"
  },
  {
    "path": "src/stablediffusion/ldm/simplet2i.py",
    "content": "'''\nThis module is provided for backward compatibility with the\noriginal (hasty) API.\n\nPlease use ldm.generate instead.\n'''\n\nfrom src.stablediffusion.ldm.generate import Generate\n\nclass T2I(Generate):\n    def __init__(self,**kwargs):\n        print(f'>> The ldm.simplet2i module is deprecated. Use ldm.generate instead. It is a drop-in replacement.')\n        super().__init__(kwargs)\n"
  },
  {
    "path": "src/stablediffusion/ldm/util.py",
    "content": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functools import partial\n\nimport multiprocessing as mp\nfrom threading import Thread\nfrom queue import Queue\n\nfrom inspect import isfunction\nfrom PIL import Image, ImageDraw, ImageFont\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new('RGB', wh, color='white')\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.load_default()\n        nc = int(40 * (wh[0] / 256))\n        lines = '\\n'.join(\n            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)\n        )\n\n        try:\n            draw.text((0, 0), lines, fill='black', font=font)\n        except UnicodeEncodeError:\n            print('Cant encode string for logging. Skipping.')\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        
print(\n            f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'\n        )\n    return total_params\n\n\ndef instantiate_from_config(config, **kwargs):\n    if not 'target' in config:\n        if config == '__is_first_stage__':\n            return None\n        elif config == '__is_unconditional__':\n            return None\n        raise KeyError('Expected key `target` to instantiate.')\n    return get_obj_from_str(config['target'])(\n        **config.get('params', dict()), **kwargs\n    )\n\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit('.', 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):\n    # create dummy dataset instance\n\n    # run prefetching\n    if idx_to_fn:\n        res = func(data, worker_id=idx)\n    else:\n        res = func(data)\n    Q.put([idx, res])\n    Q.put('Done')\n\n\ndef parallel_data_prefetch(\n    func: callable,\n    data,\n    n_proc,\n    target_data_type='ndarray',\n    cpu_intensive=True,\n    use_worker_id=False,\n):\n    # if target_data_type not in [\"ndarray\", \"list\"]:\n    #     raise ValueError(\n    #         \"Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray.\"\n    #     )\n    if isinstance(data, np.ndarray) and target_data_type == 'list':\n        raise ValueError('list expected but function got ndarray.')\n    elif isinstance(data, abc.Iterable):\n        if isinstance(data, dict):\n            print(\n                f'WARNING:\"data\" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'\n            )\n            data = list(data.values())\n        if target_data_type == 'ndarray':\n            data = np.asarray(data)\n        else:\n            data = list(data)\n   
 else:\n        raise TypeError(\n            f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'\n        )\n\n    if cpu_intensive:\n        Q = mp.Queue(1000)\n        proc = mp.Process\n    else:\n        Q = Queue(1000)\n        proc = Thread\n    # spawn processes\n    if target_data_type == 'ndarray':\n        arguments = [\n            [func, Q, part, i, use_worker_id]\n            for i, part in enumerate(np.array_split(data, n_proc))\n        ]\n    else:\n        # ceil(len(data) / n_proc), but at least 1 so range() never gets a\n        # zero step when data is empty.\n        step = max(1, -(-len(data) // n_proc))\n        arguments = [\n            [func, Q, data[offset : offset + step], i, use_worker_id]\n            for i, offset in enumerate(range(0, len(data), step))\n        ]\n    # Chunking can yield fewer than n_proc parts (e.g. 5 items over 4\n    # workers gives 3 chunks of size 2), so spawn exactly one worker per\n    # chunk instead of indexing arguments[0..n_proc-1].\n    processes = [\n        proc(target=_do_parallel_data_prefetch, args=args) for args in arguments\n    ]\n    n_workers = len(processes)\n\n    # start processes\n    print(f'Start prefetching...')\n    import time\n\n    start = time.time()\n    gather_res = [[] for _ in range(n_workers)]\n    try:\n        for p in processes:\n            p.start()\n\n        # each worker puts its [index, result] pair followed by 'Done'\n        k = 0\n        while k < n_workers:\n            # get result\n            res = Q.get()\n            if res == 'Done':\n                k += 1\n            else:\n                gather_res[res[0]] = res[1]\n\n    except Exception as e:\n        print('Exception: ', e)\n        for p in processes:\n            # threading.Thread has no terminate(); only processes can be killed\n            if hasattr(p, 'terminate'):\n                p.terminate()\n\n        raise e\n    finally:\n        for p in processes:\n            p.join()\n        print(f'Prefetching complete. [{time.time() - start} sec.]')\n\n    if target_data_type == 'ndarray':\n        if not isinstance(gather_res[0], np.ndarray):\n            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)\n\n        # order outputs\n        return np.concatenate(gather_res, axis=0)\n    elif target_data_type == 'list':\n        out = []\n        for r in gather_res:\n            out.extend(r)\n        return out\n    else:\n        return gather_res\n"
  },
  {
    "path": "src/stablediffusion/text2image_compvis.py",
    "content": "import os\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom pytorch_lightning import seed_everything\nfrom torch import autocast\n\nfrom src.stablediffusion.ldm.generate import Generate\n\nimport uuid\nimport shutil\n\n# 0 = resize\n# 1 = crop and resize\n# 2 = resize and fill\ndef resize_image(resize_mode, im, width, height):\n    LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)\n    if resize_mode == 0:\n        res = im.resize((width, height), resample=LANCZOS)\n    elif resize_mode == 1:\n        ratio = width / height\n        src_ratio = im.width / im.height\n\n        src_w = width if ratio > src_ratio else im.width * height // im.height\n        src_h = height if ratio <= src_ratio else im.height * width // im.width\n\n        resized = im.resize((src_w, src_h), resample=LANCZOS)\n        res = Image.new(\"RGB\", (width, height))\n        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\n    else:\n        ratio = width / height\n        src_ratio = im.width / im.height\n\n        src_w = width if ratio < src_ratio else im.width * height // im.height\n        src_h = height if ratio >= src_ratio else im.height * width // im.width\n\n        resized = im.resize((src_w, src_h), resample=LANCZOS)\n        res = Image.new(\"RGB\", (width, height))\n        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\n\n        if ratio < src_ratio:\n            fill_height = height // 2 - src_h // 2\n            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))\n            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))\n        elif ratio > src_ratio:\n            fill_width = width // 2 - src_w // 2\n            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))\n            
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))\n\n    return res\n\nclass Text2Image:\n    def __init__(self, model_path='models/model-epoch06-full.ckpt', use_gpu=True):\n        self.generator = Generate(weights=model_path, config='models/v1-inference.yaml')\n        try:\n            self.generator.load_model()\n        except:\n            import sys, traceback\n            traceback.print_exc(file=sys.stdout)\n        \n    def dream(self, prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int, progress: bool, sampler_name: str):\n        seed = seed_everything(seed)\n        id = str(uuid.uuid4())\n        results = self.generator.txt2img(prompt=prompt, iterations = 1, steps=ddim_steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta, width=width, height=height, sampler_name=sampler_name, outdir='storage/outputs')\n        shutil.move(results[0][0], f'storage/outputs/{id}.png')\n        return [Image.open(f'storage/outputs/{id}.png')], results[0][1]\n    \n    def translation(self, prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, sampler_name: str):\n        seed = seed_everything(seed)\n        id = str(uuid.uuid4())\n        image = init_img.convert(\"RGB\")\n        image = resize_image(1, image, width, height)\n        image.save(f'storage/init/{id}.png')\n        results = self.generator.txt2img(prompt=prompt, iterations = 1, steps=ddim_steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta, width=width, height=height, sampler_name=sampler_name, outdir='storage/outputs', init_img=f'storage/init/{id}.png', strength=denoising_strength)\n        shutil.move(results[0][0], f'storage/outputs/{id}.png')\n        return [Image.open(f'storage/outputs/{id}.png')], 
results[0][1]\n\n    def inpaint(self, prompt: str, init_img, mask_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, sampler_name: str = None):\n        '''Inpaint init_img inside the region given by mask_img.\n\n        Both images are converted to RGB, crop-resized to width x height and\n        saved under storage/init/ because the generator reads them from disk.\n        sampler_name is optional: when it is None the argument is not passed\n        to txt2img at all (presumably its configured default applies - confirm).\n        n_iter and n_samples are accepted for signature parity but unused.\n\n        Returns ([PIL.Image], results[0][1]) from the generator output\n        (presumably the seed actually used - TODO confirm).\n        '''\n        seed = seed_everything(seed)\n        id = str(uuid.uuid4())\n        image = init_img.convert(\"RGB\")\n        image = resize_image(1, image, width, height)\n        image.save(f'storage/init/{id}.png')\n        image_mask = mask_img.convert(\"RGB\")\n        image_mask = resize_image(1, image_mask, width, height)\n        image_mask.save(f'storage/init/{id}-mask.png')\n        # sampler_name used to be referenced without being defined in this\n        # method (NameError); forward it only when the caller supplies one.\n        extra = {} if sampler_name is None else {'sampler_name': sampler_name}\n        results = self.generator.txt2img(prompt=prompt, iterations = 1, steps=ddim_steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta, width=width, height=height, outdir='storage/outputs', init_img=f'storage/init/{id}.png', init_mask=f'storage/init/{id}-mask.png', strength=denoising_strength, **extra)\n        shutil.move(results[0][0], f'storage/outputs/{id}.png')\n        return [Image.open(f'storage/outputs/{id}.png')], results[0][1]\n"
  },
  {
    "path": "src/stablediffusion/text2image_diffusers.py",
    "content": "import os\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom pytorch_lightning import seed_everything\nfrom torch import autocast\n\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler, StableDiffusionPipeline, DDIMScheduler, PNDMScheduler\n\nfrom src.stablediffusion.inpaint import StableDiffusionInpaintingPipeline, preprocess, preprocess_mask\nfrom src.stablediffusion.translation import StableDiffusionImg2ImgPipeline\nfrom src.stablediffusion.dream import StableDiffusionPipeline\n\n# 0 = resize\n# 1 = crop and resize\n# 2 = resize and fill\ndef resize_image(resize_mode, im, width, height):\n    LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)\n    if resize_mode == 0:\n        res = im.resize((width, height), resample=LANCZOS)\n    elif resize_mode == 1:\n        ratio = width / height\n        src_ratio = im.width / im.height\n\n        src_w = width if ratio > src_ratio else im.width * height // im.height\n        src_h = height if ratio <= src_ratio else im.height * width // im.width\n\n        resized = im.resize((src_w, src_h), resample=LANCZOS)\n        res = Image.new(\"RGB\", (width, height))\n        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\n    else:\n        ratio = width / height\n        src_ratio = im.width / im.height\n\n        src_w = width if ratio < src_ratio else im.width * height // im.height\n        src_h = height if ratio >= src_ratio else im.height * width // im.width\n\n        resized = im.resize((src_w, src_h), resample=LANCZOS)\n        res = Image.new(\"RGB\", (width, height))\n        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\n\n        if ratio < src_ratio:\n            fill_height = height // 2 - src_h // 2\n            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))\n            
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))\n        elif ratio > src_ratio:\n            fill_width = width // 2 - src_w // 2\n            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))\n            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))\n\n    return res\n\nclass Text2Image:\n    def __init__(self, use_gpu=True):\n        self.device = torch.device('cuda' if use_gpu else 'cpu')\n        self.dtype = torch.float16 if use_gpu else torch.float32\n        model_name = 'CompVis/stable-diffusion-v1-4'\n        token = os.environ['HF_TOKEN']\n        \n        self.vae = AutoencoderKL.from_pretrained(model_name, subfolder='vae', revision=\"fp16\", use_auth_token=token)\n        self.unet = UNet2DConditionModel.from_pretrained(model_name, subfolder=\"unet\", revision=\"fp16\", use_auth_token=token)\n        self.tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n        self.text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n\n        self.scheduler = LMSDiscreteScheduler(\n            beta_start=0.00085, \n            beta_end=0.012, \n            beta_schedule=\"scaled_linear\", \n            num_train_timesteps=1000\n        )\n\n        self.img2img_scheduler = PNDMScheduler(\n            beta_start=0.00085,\n            beta_end=0.012, \n            beta_schedule=\"scaled_linear\",\n            num_train_timesteps=1000,\n            skip_prk_steps=True\n        )\n\n        self.vae = self.vae.to(self.dtype).eval().to(self.device)\n        self.text_encoder = self.text_encoder.to(self.dtype).eval().to(self.device)\n        self.unet = self.unet.to(self.dtype).eval().to(self.device)\n\n        self.inpaint_pipe = StableDiffusionInpaintingPipeline(\n            self.vae,\n            self.text_encoder,\n          
  self.tokenizer,\n            self.unet,\n            self.img2img_scheduler\n        )\n        \n        self.dream_pipe = StableDiffusionPipeline(\n            self.vae,\n            self.text_encoder,\n            self.tokenizer,\n            self.unet,\n            self.scheduler\n        )\n\n        self.translation_pipe = StableDiffusionImg2ImgPipeline(\n            self.vae,\n            self.text_encoder,\n            self.tokenizer,\n            self.unet,\n            self.img2img_scheduler\n        )\n        \n    def dream(self, prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int, progress: bool):\n        rng_seed = seed_everything(seed)\n\n        with autocast('cuda'):\n            image = self.dream_pipe(prompt, height=height, width=width, guidance_scale=cfg_scale, eta=ddim_eta, num_inference_steps=ddim_steps, progress=progress)['sample']\n\n        return image, rng_seed\n    \n    def translation(self, prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):\n        rng_seed = seed_everything(seed)\n    \n        image = init_img.convert(\"RGB\")\n        image = resize_image(1, image, width, height)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image[None].transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image)\n        image = 2.0 * image - 1.0\n\n        with autocast('cuda'):\n            image = self.translation_pipe(prompt, image, denoising_strength, ddim_steps, cfg_scale, ddim_eta, None, 'pil')['sample']\n\n        return image, rng_seed\n\n    def inpaint(self, prompt: str, init_img, mask_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):\n        rng_seed = seed_everything(seed)\n\n    
    init_img = resize_image(1, init_img, width, height)\n\n#        mask = np.array(init_img.convert('RGBA').split()[-1])\n#        mask = Image.fromarray(mask)\n\n        init_img_tensor = preprocess(init_img.convert('RGB'))\n\n        with autocast('cuda'):\n            image = self.inpaint_pipe(prompt, init_img_tensor, mask_img, denoising_strength, ddim_steps, cfg_scale, ddim_eta, None, 'pil')['sample']\n\n        return image, rng_seed\n\n    @torch.no_grad()\n    def vae_test(self, image, height: int, width: int):\n        image = image.convert(\"RGB\")\n        image = resize_image(1, image, width, height)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image[None].transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image)\n        image = 2.0 * image - 1.0\n\n        with autocast('cuda'):\n            latent_image = self.vae.decode(self.vae.encode(image.to(self.device)).sample())\n            latent_image = (latent_image / 2 + 0.5).clamp(0, 1)\n            latent_image = latent_image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if latent_image.ndim == 3:\n            latent_image = latent_image[None, ...]\n        latent_image = (latent_image * 255).round().astype('uint8')\n        latent_image = [Image.fromarray(image) for image in latent_image]\n\n        return latent_image\n"
  },
  {
    "path": "src/stablediffusion/translation.py",
    "content": "import inspect\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\n\nimport PIL\nfrom diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n\n\ndef preprocess(image):\n    w, h = image.size\n    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32\n    image = image.resize((w, h), resample=PIL.Image.LANCZOS)\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image)\n    return 2.0 * image - 1.0\n\n\nclass StableDiffusionImg2ImgPipeline(DiffusionPipeline):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler],\n    ):\n        super().__init__()\n        scheduler = scheduler.set_format(\"pt\")\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        init_image: torch.FloatTensor,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        eta: Optional[float] = 0.0,\n        generator: Optional[torch.Generator] = None,\n        output_type: Optional[str] = \"pil\",\n    ):\n\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or 
`list` but is {type(prompt)}\")\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should in [0.0, 1.0] but is {strength}\")\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        offset = 0\n        if accepts_offset:\n            offset = 1\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n\n        # encode the init image into latents and scale the latents\n        init_latents = self.vae.encode(init_image.to(self.device)).sample()\n        init_latents = 0.18215 * init_latents\n\n        # prepare init_latents noise to latents\n        init_latents = torch.cat([init_latents] * batch_size)\n\n        # get the original timestep using init_timestep\n        init_timestep = int(num_inference_steps * strength) + offset\n        init_timestep = min(init_timestep, num_inference_steps)\n        timesteps = self.scheduler.timesteps[-init_timestep]\n        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)\n\n        # add noise to latents using the timesteps\n        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)\n        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)\n\n        # get prompt text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        latents = init_latents\n        t_start = max(num_inference_steps - init_timestep + offset, 0)\n        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)[\"sample\"]\n\n            # perform guidance\n            if do_classifier_free_guidance:\n              
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[\"prev_sample\"]\n\n        # scale and decode the image latents with vae\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents)\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        return {\"sample\": image}\n"
  },
  {
    "path": "storage/init/.keep",
    "content": ""
  },
  {
    "path": "storage/outputs/.keep",
    "content": ""
  },
  {
    "path": "win10fix.bat",
    "content": "python src\\scripts\\win10patch.py"
  }
]